Skip to content

Commit

Permalink
Simplify the variable management and add new_lits convenience feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekker1 committed Oct 17, 2024
1 parent 0d5778c commit e8da0d4
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 295 deletions.
37 changes: 14 additions & 23 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
#[cfg(feature = "external-propagation")]
impl crate::solver::propagation::SolvingActions for #actions_ident {
fn new_var(&mut self) -> crate::Var {
self.vars.as_ref().lock().unwrap().next().expect("variable pool exhaused")
self.vars.as_ref().lock().unwrap().next_var()
}
fn add_observed_var(&mut self, var: crate::Var) {
unsafe { #krate::ipasir_add_observed_var( self.ptr, var.0.get()) };
Expand Down Expand Up @@ -332,34 +332,29 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote! {
fn new_var(&mut self) -> crate::Var {
#[cfg(feature = "external-propagation")]
let var = #vars .as_ref().lock().unwrap().next().expect("variable pool exhaused");
let var = #vars .as_ref().lock().unwrap().next_var();
#[cfg(not(feature = "external-propagation"))]
let var = #vars .next().expect("variable pool exhaused");
let var = #vars .next_var();
var
}
}
} else {
quote! {
fn new_var(&mut self) -> crate::Var {
#vars .next().expect("variable pool exhaused")
}
}
};

let next_var_range = if opts.ipasir_up {
quote! {
fn next_var_range(&mut self, size: usize) -> Option<crate::VarRange> {
fn new_var_range(&mut self, len: usize) -> crate::VarRange {
#[cfg(feature = "external-propagation")]
let r = #vars .as_ref().lock().unwrap().next_var_range(size);
let var = #vars .as_ref().lock().unwrap().next_var_range(len);
#[cfg(not(feature = "external-propagation"))]
let r = #vars .next_var_range(size);
r
let var = #vars .next_var_range(len);
var
}
}
} else {
quote! {
fn next_var_range(&mut self, size: usize) -> Option<crate::VarRange> {
#vars .next_var_range(size)
fn new_var(&mut self) -> crate::Var {
#vars .next_var()
}

fn new_var_range(&mut self, len: usize) -> crate::VarRange {
let var = #vars .next_var_range(len);
var
}
}
};
Expand Down Expand Up @@ -441,10 +436,6 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}
}

impl crate::solver::NextVarRange for #ident {
#next_var_range
}

impl crate::solver::Solver for #ident {
type ValueFn = #sol_ident;

Expand Down
111 changes: 20 additions & 91 deletions crates/pindakaas/src/bool_linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2126,9 +2126,8 @@ mod tests {
cardinality::{tests::card_test_suite, Cardinality},
cardinality_one::{tests::card1_test_suite, CardinalityOne, PairwiseEncoder},
helpers::tests::{assert_checker, assert_encoding, assert_solutions, expect_file},
solver::NextVarRange,
sorted::SortedEncoder,
Cnf, Coeff, Encoder, Lit, Unsatisfiable,
ClauseDatabase, Cnf, Coeff, Encoder, Lit, Unsatisfiable,
};

pub(crate) fn construct_terms<L: Into<Lit> + Clone>(terms: &[(L, Coeff)]) -> Vec<Part> {
Expand All @@ -2141,12 +2140,7 @@ mod tests {
#[test]
fn test_aggregator_at_least_one_negated() {
let mut cnf = Cnf::default();
let (a, b, c, d) = cnf
.next_var_range(4)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d) = cnf.new_lits();
// Correctly detect that all but one literal can be set to true
assert_eq!(
BoolLinAggregator::default().aggregate(
Expand All @@ -2166,12 +2160,7 @@ mod tests {

// Correctly detect equal k
let mut cnf = Cnf::default();
let (a, b, c) = cnf
.next_var_range(3)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand All @@ -2192,12 +2181,7 @@ mod tests {
#[test]
fn test_aggregator_combine() {
let mut cnf = Cnf::default();
let (a, b, c) = cnf
.next_var_range(3)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c) = cnf.new_lits();
// Simple aggregation of multiple occurrences of the same literal
assert_eq!(
BoolLinAggregator::default().aggregate(
Expand Down Expand Up @@ -2260,12 +2244,7 @@ mod tests {
#[test]
fn test_aggregator_detection() {
let mut cnf = Cnf::default();
let (a, b, c, d) = cnf
.next_var_range(4)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d) = cnf.new_lits();

// Correctly detect at most one
assert_eq!(
Expand Down Expand Up @@ -2412,7 +2391,7 @@ mod tests {
#[test]
fn test_aggregator_equal_one() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(3).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(3).iter_lits().collect_vec();
// An exactly one constraint adds an exactly one constraint
assert_eq!(
BoolLinAggregator::default().aggregate(
Expand All @@ -2434,12 +2413,7 @@ mod tests {
#[test]
fn test_aggregator_false_trivial_unsat() {
let mut cnf = Cnf::default();
let (a, b, c, d, e, f, g) = cnf
.next_var_range(7)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e, f, g) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2469,12 +2443,7 @@ mod tests {
#[test]
fn test_aggregator_neg_coeff() {
let mut cnf = Cnf::default();
let (a, b, c) = cnf
.next_var_range(3)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c) = cnf.new_lits();

// Correctly convert a negative coefficient
assert_eq!(
Expand Down Expand Up @@ -2526,12 +2495,7 @@ mod tests {

// Correctly convert multiple negative coefficients with AMO constraints
let mut cnf = Cnf::default();
let (a, b, c, d, e, f) = cnf
.next_var_range(6)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e, f) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2563,12 +2527,7 @@ mod tests {

// Correctly convert multiple negative coefficients with side constraints
let mut cnf = Cnf::default();
let (a, b, c, d, e, f) = cnf
.next_var_range(6)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e, f) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2605,12 +2564,7 @@ mod tests {

// Correctly convert GreaterEq into LessEq with side constrains
let mut cnf = Cnf::default();
let (a, b, c, d, e, f) = cnf
.next_var_range(6)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e, f) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2642,12 +2596,7 @@ mod tests {

// Correctly convert GreaterEq into LessEq with side constrains
let mut cnf = Cnf::default();
let (a, b, c, d, e, f) = cnf
.next_var_range(6)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e, f) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2676,12 +2625,7 @@ mod tests {

// Correctly account for the coefficient in the Dom bounds
let mut cnf = Cnf::default();
let (a, b, c) = cnf
.next_var_range(3)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand All @@ -2708,12 +2652,7 @@ mod tests {

// Correctly convert GreaterEq into LessEq with side constrains
let mut cnf = Cnf::default();
let (a, b, c, d, e) = cnf
.next_var_range(5)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d, e) = cnf.new_lits();
assert_eq!(
BoolLinAggregator::default().aggregate(
&mut cnf,
Expand Down Expand Up @@ -2751,12 +2690,7 @@ mod tests {
#[test]
fn test_aggregator_sort_same_coefficients() {
let mut cnf = Cnf::default();
let (a, b, c, d) = cnf
.next_var_range(4)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d) = cnf.new_lits();

assert_eq!(
BoolLinAggregator::default()
Expand Down Expand Up @@ -2787,7 +2721,7 @@ mod tests {
#[test]
fn test_aggregator_sort_same_coefficients_using_minimal_chain() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(5).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(5).iter_lits().collect_vec();
assert_eq!(
BoolLinAggregator::default()
.sort_same_coefficients(SortedEncoder::default(), 2)
Expand Down Expand Up @@ -2816,7 +2750,7 @@ mod tests {
#[test]
fn test_aggregator_unsat() {
let mut db = Cnf::default();
let vars = db.next_var_range(3).unwrap().iter_lits().collect_vec();
let vars = db.new_var_range(3).iter_lits().collect_vec();

// Constant cannot be reached
assert_eq!(
Expand Down Expand Up @@ -2870,12 +2804,7 @@ mod tests {
#[test]
fn test_encoders() {
let mut cnf = Cnf::default();
let (a, b, c, d) = cnf
.next_var_range(4)
.unwrap()
.iter_lits()
.collect_tuple()
.unwrap();
let (a, b, c, d) = cnf.new_lits();
// TODO encode this if encoder does not support constraint
PairwiseEncoder::default()
.encode(
Expand Down Expand Up @@ -2919,7 +2848,7 @@ mod tests {
#[test]
fn test_pb_encode() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(4).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(4).iter_lits().collect_vec();
LinearEncoder::<StaticLinEncoder>::default()
.encode(
&mut cnf,
Expand All @@ -2938,7 +2867,7 @@ mod tests {
#[test]
fn test_sort_same_coefficients_2() {
let mut db = Cnf::default();
let vars = db.next_var_range(5).unwrap().iter_lits().collect_vec();
let vars = db.new_var_range(5).iter_lits().collect_vec();
let mut agg = BoolLinAggregator::default();
let _ = agg.sort_same_coefficients(SortedEncoder::default(), 3);
let mut encoder = LinearEncoder::<StaticLinEncoder<TotalizerEncoder>>::default();
Expand Down
14 changes: 7 additions & 7 deletions crates/pindakaas/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub(crate) mod tests {
#[test]
fn test_card_le_2_3() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(3).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(3).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand All @@ -129,7 +129,7 @@ pub(crate) mod tests {
#[test]
fn test_card_eq_1_3() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(3).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(3).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand All @@ -151,7 +151,7 @@ pub(crate) mod tests {
#[test]
fn test_card_eq_2_3() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(3).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(3).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand All @@ -173,7 +173,7 @@ pub(crate) mod tests {
#[test]
fn test_card_eq_2_4() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(4).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(4).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand All @@ -195,7 +195,7 @@ pub(crate) mod tests {
#[test]
fn test_card_eq_3_5() {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range(5).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range(5).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand Down Expand Up @@ -226,7 +226,7 @@ pub(crate) mod tests {
cardinality::{Cardinality, SortingNetworkEncoder},
helpers::tests::assert_solutions,
sorted::{SortedEncoder, SortedStrategy},
Cnf, Encoder, NextVarRange,
ClauseDatabase, Cnf, Encoder,
};

#[test]
Expand Down Expand Up @@ -289,7 +289,7 @@ pub(crate) mod tests {
macro_rules! test_card {
($encoder:expr,$n:expr,$cmp:expr,$k:expr) => {
let mut cnf = Cnf::default();
let vars = cnf.next_var_range($n).unwrap().iter_lits().collect_vec();
let vars = cnf.new_var_range($n).iter_lits().collect_vec();
$encoder
.encode(
&mut cnf,
Expand Down
Loading

0 comments on commit e8da0d4

Please sign in to comment.