From b6bf96df764b84917318bb4c4aa4ae93e2073c99 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Wed, 10 Apr 2024 13:05:54 -0700 Subject: [PATCH 01/16] Added untested puzzle and cyclic solver implementations --- src/solver/algorithm/strong/cyclic.rs | 330 ++++++++++++++++++++++++++ src/solver/algorithm/strong/puzzle.rs | 244 +++++++++++++++---- src/solver/mod.rs | 2 + 3 files changed, 536 insertions(+), 40 deletions(-) diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index 0950bd1..0e5c3c4 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -7,3 +7,333 @@ //! #### Authorship //! //! - Max Fierro, 12/3/2023 (maxfierro@berkeley.edu) +//! - Ishir Garg, 3/12/2024 (ishirgarg@berkeley.edu) + +use std::collections::{VecDeque, HashMap}; +use anyhow::{Context, Result}; + +use crate::database::volatile; +use crate::database::{KVStore, Tabular}; +use crate::game::{Bounded, DTransition, Extensive, SimpleSum}; +use crate::solver::SimpleUtility; +use crate::interface::IOMode; +use crate::model::{PlayerCount, Remoteness, State, Turn}; +use crate::solver::record::sur::RecordBuffer; +use crate::solver::RecordType; + +/* CONSTANTS */ + +/// The exact number of bits that are used to encode remoteness. +const REMOTENESS_SIZE: usize = 16; + +/// The maximum number of bits that can be used to encode a record. +const BUFFER_SIZE: usize = 128; + +/// The exact number of bits that are used to encode utility for one player. +const UTILITY_SIZE: usize = 2; + + +pub fn two_player_zero_sum_dynamic_solver(game: &G, mode: IOMode) -> Result<()> +where + G: DTransition + Bounded + SimpleSum<2> +{ + let mut db = volatile_database(game) + .context("Failed to initialize database.")?; + basic_loopy_solver(game, &mut db)?; + Ok(()) +} + +fn basic_loopy_solver (game: &G, db: &mut D) -> Result<()> +where + G: DTransition + Bounded + SimpleSum<2>, + D: KVStore, +{ + let mut winning_frontier = VecDeque::new(); + let mut tying_frontier = VecDeque::new(); + let mut losing_frontier = VecDeque::new(); + + let mut child_counts = HashMap::new(); + + enqueue_children(&mut winning_frontier, &mut tying_frontier, &mut losing_frontier, game.start(), game, &mut child_counts, db)?; + + // Process winning and losing frontiers + while !winning_frontier.is_empty() && !losing_frontier.is_empty() && !tying_frontier.is_empty() { + let child = if !losing_frontier.is_empty() { + losing_frontier.pop_front().unwrap() + } else if !winning_frontier.is_empty() { + winning_frontier.pop_front().unwrap() + } else { + tying_frontier.pop_front().unwrap() + }; + + let db_entry = RecordBuffer::from(db.get(child).unwrap()) + .context("Failed to create record for middle state.")?; + let child_utility = db_entry + .get_utility(game.turn(child)) + .context("Failed to get utility from record.")?; + let child_remoteness = db_entry + .get_remoteness(); + + + let parents = game.retrograde(child); + // If child is a losing position + if matches!(child_utility, SimpleUtility::LOSE) { + for parent in parents { + if *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueueing stage") > 0 { + // Add database entry + let mut buf = RecordBuffer::new(game.players()) + .context("Failed to create record for end state.")?; + buf.set_utility([SimpleUtility::WIN, SimpleUtility::LOSE])?; + buf.set_remoteness(child_remoteness + 1)?; + db.put(parent, &buf); + + // Update child counts + child_counts.insert(parent, 0); + + // Add parent to win frontier + winning_frontier.push_back(parent); + } + } + } + // If child is a winning position + else if matches!(child_utility, SimpleUtility::WIN) { + for parent in parents { + let child_count = *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueuing stage"); + // Parent has already been solved + if child_count == 0 { + continue; + } + // This is the last child left to process + if child_count == 1 { + // Add database entry + let mut buf = RecordBuffer::new(game.players()) + .context("Failed to create record for end state.")?; + buf.set_utility([SimpleUtility::LOSE, SimpleUtility::WIN])?; + buf.set_remoteness(child_remoteness + 1)?; + db.put(parent, &buf); + + // Add parent to win frontier + losing_frontier.push_back(parent); + } + // Update child count + child_counts.insert(parent, child_count - 1); + } + } + // Child should never be a tying position + else if matches!(child_utility, SimpleUtility::TIE) { + for parent in parents { + let child_count = *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueuing stage"); + // Parent has already been solved + if child_count == 0 { + continue; + } + // Add database entry + let mut buf = RecordBuffer::new(game.players()) + .context("Failed to create record for end state.")?; + buf.set_utility([SimpleUtility::TIE, SimpleUtility::TIE])?; + buf.set_remoteness(child_remoteness + 1)?; + db.put(parent, &buf); + + // Add parent to win frontier + tying_frontier.push_back(parent); + // Update child count + child_counts.insert(parent, 0); + } + + } + else { + panic!["Position with invalid utility found in frontiers"]; + } + } + + // Assign drawing utility + for (parent, child_count) in child_counts { + if child_count > 0 { + let mut buf = RecordBuffer::new(game.players()) + .context("Failed to create record for end state.")?; + buf.set_utility([SimpleUtility::DRAW, SimpleUtility::DRAW])?; + db.put(parent, &buf); + } + } + + Ok(()) +} + +/// Set up the initial frontiers and primitive position database entries +fn enqueue_children(winning_frontier: &mut VecDeque, + tying_frontier: &mut VecDeque, + losing_frontier: &mut VecDeque, + curr_state: State, + game: &G, + child_counts: &mut HashMap, + db: &mut D +) -> Result<()> +where + G: DTransition + Bounded + SimpleSum<2>, + D: KVStore, +{ + if game.end(curr_state) { + let mut buf = RecordBuffer::new(game.players()) + .context("Failed to create placeholder record.")?; + buf.set_utility(game.utility(curr_state)) + .context("Failed to copy utility values to record.")?; + buf.set_remoteness(0) + .context("Failed to set remoteness for end state.")?; + db.put(curr_state, &buf); + + match game.utility(curr_state).get(game.turn(curr_state)) { + Some(SimpleUtility::WIN) => winning_frontier.push_back(curr_state), + Some(SimpleUtility::TIE) => tying_frontier.push_back(curr_state), + Some(SimpleUtility::LOSE) => losing_frontier.push_back(curr_state), + _ => panic!["Utility for primitive ending position found to be draw"] + } + return Ok(()); + } + + // Enqueue primitive positions into frontiers + let children = game.prograde(curr_state); + child_counts.insert(curr_state, children.len()); + + for child in children { + if child_counts.contains_key(&child) { + continue; + } + enqueue_children(winning_frontier, tying_frontier, losing_frontier, child, game, child_counts, db)?; + } + + Ok(()) +} + + + +/* DATABASE INITIALIZATION */ + +/// Initializes a volatile database, creating a table schema according to the +/// solver record layout, initializing a table with that schema, and switching +/// to that table before returning the database handle. +fn volatile_database(game: &G) -> Result +where + G: Extensive<2>, +{ + let id = game.id(); + let db = volatile::Database::initialize(); + + let schema = RecordType::SUR(2) + .try_into() + .context("Failed to create table schema for solver records.")?; + db.create_table(&id, schema) + .context("Failed to create database table for solution set.")?; + db.select_table(&id) + .context("Failed to select solution set database table.")?; + + Ok(db) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use crate::game::{Game, GameData, Bounded, DTransition, Extensive, SimpleSum}; + use crate::model::{State, Turn}; + use std::collections::{HashMap, VecDeque}; + use crate::interface::{IOMode, SolutionMode}; + use crate::solver::SimpleUtility; + + use super::{enqueue_children, volatile_database}; + + struct GameNode { + // state: State, + turn: Turn, + utility: Vec, + children: Vec + } + + struct GameGraph { + num_states: u32, + adj_list: Vec, + } + + impl Game for GameGraph { + fn initialize(variant: Option) -> Result + where + Self: Sized + { + unimplemented!(); + } + + fn forward(&mut self, history: Vec) -> Result<()> { + unimplemented!(); + } + + fn id(&self) -> String { + String::from("GameGraph") + } + + fn info(&self) -> GameData { + unimplemented!(); + } + + fn solve(&self, mode: IOMode, method: SolutionMode) -> Result<()> { + unimplemented!(); + } + } + + impl Bounded for GameGraph { + fn start(&self) -> u64 { + 0 + } + + fn end(&self, state: State) -> bool { + self.adj_list[state as usize].children.is_empty() + } + } + + impl Extensive<2> for GameGraph { + fn turn(&self, state: State) -> Turn { + self.adj_list[state as usize].turn + } + } + + impl SimpleSum<2> for GameGraph { + fn utility(&self, state: State) -> [SimpleUtility; 2] { + let util = &self.adj_list[state as usize].utility; + [util[0], util[1]] + } + } + + impl DTransition for GameGraph { + fn prograde(&self, state: State) -> Vec { + self.adj_list[state as usize].children.clone() + } + + fn retrograde(&self, state: State) -> Vec { + todo![]; + } + } + + #[test] + fn enqueues_children_properly() -> Result<()>{ + let graph = GameGraph { + num_states: 2, + adj_list: vec![ + GameNode {turn: 0, utility: vec![], children: vec![1]}, + GameNode {turn: 1, utility: vec![SimpleUtility::LOSE, SimpleUtility::WIN], children: vec![]} + ] + }; + + let mut db = volatile_database(&graph)?; + + let mut winning_frontier = VecDeque::new(); + let mut tying_frontier = VecDeque::new(); + let mut losing_frontier = VecDeque::new(); + let mut child_counts = HashMap::new(); + + enqueue_children(&mut winning_frontier, &mut tying_frontier, &mut losing_frontier, graph.start(), &graph, &mut child_counts, &mut db)?; + + assert!(winning_frontier.is_empty()); + assert!(tying_frontier.is_empty()); + assert_eq!(losing_frontier, vec![1]); + assert_eq!(child_counts, HashMap::from([(0, 1), (1, 0)])); + + Ok(()) + } +} diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index d969b1b..5b48ea1 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -7,81 +7,149 @@ use anyhow::{Context, Result}; +use std::collections::{HashSet, VecDeque, HashMap}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; -use crate::game::{Bounded, DTransition, GeneralSum, Playable, STransition}; +use crate::game::{Bounded, DTransition, ClassicPuzzle, Extensive}; use crate::interface::IOMode; -use crate::model::{PlayerCount, Remoteness, State, Utility}; -use crate::solver::record::mur::RecordBuffer; -use crate::solver::{RecordType, MAX_TRANSITIONS}; +use crate::model::{Remoteness, State}; +use crate::solver::record::sur::RecordBuffer; +use crate::solver::RecordType; +use crate::solver::SimpleUtility; +use crate::solver::error::SolverError::SolverViolation; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where - G: DTransition + Bounded + Playable + GeneralSum, + G: DTransition + Bounded + ClassicPuzzle, { let mut db = volatile_database(game) .context("Failed to initialize volatile database.")?; - bfs(&mut db, game) + reverse_bfs(&mut db, game) .context("Failed solving algorithm execution.")?; Ok(()) } -fn bfs(game: &G, db: &mut D) +fn reverse_bfs(db: &mut D, game: &G) -> Result<()> where - G: DTransition + Bounded + SimpleSum, + G: DTransition + Bounded + ClassicPuzzle, D: KVStore, { - let end_states = discover_end_states_helper(db, game); + // Get end states and create frontiers + let mut child_counts = discover_child_counts(db, game); + let end_states = child_counts.iter().filter(|&x| *x.1 == 0).map(|x| *x.0); - for state in end_states { - let mut buf = RecordBuffer::new() + let mut winning_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); + let mut losing_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); + for end_state in end_states { + match ClassicPuzzle::utility(game, end_state) { + SimpleUtility::WIN => winning_queue.push_back((end_state, 0)), + SimpleUtility::LOSE => losing_queue.push_back((end_state, 0)), + SimpleUtility::TIE => Err(SolverViolation { + name: "PuzzleSolver".to_string(), + hint: format!("Primitive end position cannot have utility TIE for a puzzle"), + })?, + SimpleUtility::DRAW => Err(SolverViolation { + name: "PuzzleSolver".to_string(), + hint: format!("Primitive end position cannot have utility DRAW for a puzzle"), + })?, + } + } + + // Contains states that have already been visited + let mut visited = HashSet::new(); + + // Perform BFS on winning states + while let Some((state, remoteness)) = winning_queue.pop_front() { + + let mut buf = RecordBuffer::new(1) .context("Failed to create placeholder record.")?; - buf.set_remoteness(0) - .context("Failed to set remoteness for end state.")?; + buf.set_utility([SimpleUtility::WIN]) + .context("Failed to set remoteness for state.")?; + buf.set_remoteness(remoteness) + .context("Failed to set remoteness for state.")?; db.put(state, &buf); - bfs_state(db, game, state); + child_counts.insert(state, 0); + visited.insert(state); + let parents = game.retrograde(state); + + for parent in parents { + if !visited.contains(&parent) { + winning_queue.push_back((parent, remoteness + 1)); + } + } } -} -fn bfs_state(db: &mut D, game: &G) -where - G: DTransition + Bounded + SimpleSum, - D: KVStore, -{ + // Perform BFS on losing states + while let Some((state, remoteness)) = losing_queue.pop_front() { + let mut buf = RecordBuffer::new(1) + .context("Failed to create placeholder record.")?; + buf.set_utility([SimpleUtility::LOSE]) + .context("Failed to set remoteness for state.")?; + buf.set_remoteness(remoteness) + .context("Failed to set remoteness for state.")?; + db.put(state, &buf); + + visited.insert(state); + let parents = game.retrograde(state); + + for parent in parents { + // The check below is needed, because it is theoretically possible for child_counts to + // NOT contain a position discovered by retrograde(). Consider a 3-node game tree with + // starting vertex 1, and edges (1 -> 2), (3 -> 2), where 2 is a losing primitive ending position. + // In this case, running discover_child_counts() on 1 above only gets child_counts for states 1 and 2, + // however calling retrograde on end state 2 in this BFS portion will discover state 2 + // for the first time. + match child_counts.get(&parent) { + Some(count) => child_counts.insert(parent, count - 1), + None => child_counts.insert(parent, game.prograde(parent).len() - 1), + }; + + if !visited.contains(&parent) && *child_counts.get(&state).unwrap() == 0 { + losing_queue.push_back((parent, remoteness + 1)); + } + } + } + + // Get remaining draw positions + for (state, count) in child_counts { + if count > 0 { + let mut buf = RecordBuffer::new(1) + .context("Failed to create placeholder record.")?; + buf.set_utility([SimpleUtility::DRAW]) + .context("Failed to set remoteness for state.")?; + db.put(state, &buf); + } + } + Ok(()) } -fn discover_end_states(db: &mut D, game: &G) -> Vec +fn discover_child_counts(db: &mut D, game: &G) -> HashMap where - G: DTransition + Bounded + SimpleSum, + G: DTransition + Bounded + ClassicPuzzle, D: KVStore, { - let visited = HashSet::new(); - let end_states = Vec::new(); + let mut child_counts = HashMap::new(); - discover_end_states(db, game, game.start(), visited, end_states); + discover_child_counts_helper(db, game, game.start(), &mut child_counts); - end_states + child_counts } -fn discover_end_states_helper(db: &mut D, game: &G, state: State, visited: HashSet, end_states: Vec) +fn discover_child_counts_helper(db: &mut D, game: &G, state: State, child_counts: &mut HashMap) where - G: DTransition + Bounded + SimpleSum, + G: DTransition + Bounded + ClassicPuzzle, D: KVStore, { - visited.insert(state); - - if game.end(state) { - end_states.insert(state); - } + child_counts.insert(state, game.prograde(state).len()); for child in game.prograde(state) { - if !visted.contains(child) { - discover_end_states(db, game, child, visited, end_states); + if !child_counts.contains_key(&child) { + discover_child_counts_helper(db, game, child, child_counts); } } } @@ -92,12 +160,12 @@ where /// to that table before returning the database handle. fn volatile_database(game: &G) -> Result where - G: Playable, + G: Extensive, { let id = game.id(); let db = volatile::Database::initialize(); - let schema = RecordType::REMOTE(N) + let schema = RecordType::SUR(1) .try_into() .context("Failed to create table schema for solver records.")?; db.create_table(&id, schema) @@ -110,9 +178,105 @@ where #[cfg(test)] -mod test { +mod tests { + use anyhow::Result; + use crate::game::{Game, GameData, Bounded, DTransition, Extensive, ClassicPuzzle, SimpleSum}; + use crate::model::{State, Turn}; + use std::collections::{HashMap, VecDeque}; + use crate::interface::{IOMode, SolutionMode}; + use crate::solver::SimpleUtility; + + use super::{discover_child_counts, volatile_database}; + + struct GameNode { + utility: Option, // Is None for non-primitive puzzle nodes + children: Vec + } + + struct PuzzleGraph { + adj_list: Vec, + } + + impl PuzzleGraph { + fn size(&self) -> u64 { + self.adj_list.len() as u64 + } + } + + impl Game for PuzzleGraph { + fn initialize(variant: Option) -> Result + where + Self: Sized + { + unimplemented!(); + } + + fn forward(&mut self, history: Vec) -> Result<()> { + unimplemented!(); + } + + fn id(&self) -> String { + String::from("GameGraph") + } + + fn info(&self) -> GameData { + unimplemented!(); + } + + fn solve(&self, mode: IOMode, method: SolutionMode) -> Result<()> { + unimplemented!(); + } + } + + impl Bounded for PuzzleGraph { + fn start(&self) -> u64 { + 0 + } + + fn end(&self, state: State) -> bool { + self.adj_list[state as usize].children.is_empty() + } + } + + impl Extensive<1> for PuzzleGraph { + fn turn(&self, state: State) -> Turn { + 0 + } + } + + impl SimpleSum<1> for PuzzleGraph { + fn utility(&self, state: State) -> [SimpleUtility; 1] { + [self.adj_list[state as usize].utility.unwrap()] + } + } + + impl ClassicPuzzle for PuzzleGraph {} + + impl DTransition for PuzzleGraph { + fn prograde(&self, state: State) -> Vec { + self.adj_list[state as usize].children.clone() + } + + fn retrograde(&self, state: State) -> Vec { + (0..self.size()).filter(|&s| self.adj_list[s as usize].children.contains(&state)).collect() + } + } + #[test] - fn test() { - assert!(false); + fn gets_child_counts_correctly() -> Result<()>{ + let graph = PuzzleGraph { + adj_list: vec![ + GameNode {utility: None, children: vec![1]}, + GameNode {utility: Some(SimpleUtility::LOSE), children: vec![]} + ] + }; + + let mut db = volatile_database(&graph)?; + + let child_counts = discover_child_counts(&mut db, &graph); + + assert_eq!(child_counts, HashMap::from([(0, 1), (1, 0)])); + + Ok(()) } } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index f64e039..d3c9b55 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -16,6 +16,7 @@ use crate::model::PlayerCount; use crate::solver::error::SolverError::RecordViolation; use anyhow::Result; +use std::fmt::Display; /// Describes the maximum number of states that are one move away from any state /// within a game. Used to allocate statically-sized arrays on the stack for @@ -97,6 +98,7 @@ pub mod algorithm { pub mod strong { pub mod acyclic; pub mod cyclic; + pub mod puzzle; } /// Solving algorithms for deterministic complete-information games that From 6573e2cb76c657dd87c811a6d100e9e677cf2a0c Mon Sep 17 00:00:00 2001 From: Max Fierro Date: Thu, 11 Apr 2024 02:44:32 -0700 Subject: [PATCH 02/16] Added blanket impl for ClassicGame --- src/game/mod.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/game/mod.rs b/src/game/mod.rs index 74c0770..6a89a4a 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -354,6 +354,15 @@ where } } +impl SimpleSum<2> for G +where + G: ClassicGame, +{ + fn utility(&self, state: State) -> [SimpleUtility; 2] { + todo!() + } +} + impl SimpleSum<1> for G where G: ClassicPuzzle, From cee7a44ddb872c61c50e00ef05de6814a24e218b Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Thu, 11 Apr 2024 04:24:16 -0700 Subject: [PATCH 03/16] updated puzzle and cyclic solvers to accomodate new game api structure --- src/solver/algorithm/strong/cyclic.rs | 25 ++++++------- src/solver/algorithm/strong/puzzle.rs | 52 +++++++++++++-------------- 2 files changed, 36 insertions(+), 41 deletions(-) diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index 73e7046..afacaa4 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -13,8 +13,8 @@ use anyhow::{Context, Result}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; -use crate::game::{Bounded, DTransition, Extensive, SimpleSum}; -use crate::solver::SimpleUtility; +use crate::game::{Bounded, DTransition, Extensive, SimpleSum, Game}; +use crate::model::SimpleUtility; use crate::interface::IOMode; use crate::model::{PlayerCount, Remoteness, State, Turn}; use crate::solver::record::sur::RecordBuffer; @@ -34,7 +34,7 @@ const UTILITY_SIZE: usize = 2; pub fn two_player_zero_sum_dynamic_solver(game: &G, mode: IOMode) -> Result<()> where - G: DTransition + Bounded + SimpleSum<2> + G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game { let mut db = volatile_database(game) .context("Failed to initialize database.")?; @@ -44,8 +44,8 @@ where fn basic_loopy_solver (game: &G, db: &mut D) -> Result<()> where - G: DTransition + Bounded + SimpleSum<2>, - D: KVStore, + G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, + D: KVStore, { let mut winning_frontier = VecDeque::new(); let mut tying_frontier = VecDeque::new(); @@ -168,8 +168,8 @@ fn enqueue_children(winning_frontier: &mut VecDeque, db: &mut D ) -> Result<()> where - G: DTransition + Bounded + SimpleSum<2>, - D: KVStore, + G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, + D: KVStore, { if game.end(curr_state) { let mut buf = RecordBuffer::new(game.players()) @@ -212,7 +212,7 @@ where /// to that table before returning the database handle. fn volatile_database(game: &G) -> Result where - G: Extensive<2>, + G: Extensive<2> + Game, { let id = game.id(); let db = volatile::Database::initialize(); @@ -235,12 +235,11 @@ mod tests { use crate::model::{State, Turn}; use std::collections::{HashMap, VecDeque}; use crate::interface::{IOMode, SolutionMode}; - use crate::solver::SimpleUtility; + use crate::model::SimpleUtility; use super::{enqueue_children, volatile_database}; struct GameNode { - // state: State, turn: Turn, utility: Vec, children: Vec @@ -252,17 +251,13 @@ mod tests { } impl Game for GameGraph { - fn initialize(variant: Option) -> Result + fn new(variant: Option) -> Result where Self: Sized { unimplemented!(); } - fn forward(&mut self, history: Vec) -> Result<()> { - unimplemented!(); - } - fn id(&self) -> String { String::from("GameGraph") } diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 5b48ea1..96bef9f 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -10,17 +10,17 @@ use anyhow::{Context, Result}; use std::collections::{HashSet, VecDeque, HashMap}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; -use crate::game::{Bounded, DTransition, ClassicPuzzle, Extensive}; +use crate::game::{Game, Bounded, DTransition, ClassicPuzzle, Extensive}; use crate::interface::IOMode; use crate::model::{Remoteness, State}; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; -use crate::solver::SimpleUtility; +use crate::model::SimpleUtility; use crate::solver::error::SolverError::SolverViolation; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where - G: DTransition + Bounded + ClassicPuzzle, + G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, { let mut db = volatile_database(game) .context("Failed to initialize volatile database.")?; @@ -34,8 +34,8 @@ where fn reverse_bfs(db: &mut D, game: &G) -> Result<()> where - G: DTransition + Bounded + ClassicPuzzle, - D: KVStore, + G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + D: KVStore, { // Get end states and create frontiers let mut child_counts = discover_child_counts(db, game); @@ -97,12 +97,14 @@ where let parents = game.retrograde(state); for parent in parents { - // The check below is needed, because it is theoretically possible for child_counts to - // NOT contain a position discovered by retrograde(). Consider a 3-node game tree with - // starting vertex 1, and edges (1 -> 2), (3 -> 2), where 2 is a losing primitive ending position. - // In this case, running discover_child_counts() on 1 above only gets child_counts for states 1 and 2, - // however calling retrograde on end state 2 in this BFS portion will discover state 2 - // for the first time. + // The check below is needed, because it is theoretically possible + // for child_counts to NOT contain a position discovered by + // retrograde(). Consider a 3-node game tree with starting vertex 1, + // and edges (1 -> 2), (3 -> 2), where 2 is a losing primitive + // ending position. In this case, running discover_child_counts() on + // 1 above only gets child_counts for states 1 and 2, however + // calling retrograde on end state 2 in this BFS portion will + // discover state 2 for the first time. match child_counts.get(&parent) { Some(count) => child_counts.insert(parent, count - 1), None => child_counts.insert(parent, game.prograde(parent).len() - 1), @@ -130,8 +132,8 @@ where fn discover_child_counts(db: &mut D, game: &G) -> HashMap where - G: DTransition + Bounded + ClassicPuzzle, - D: KVStore, + G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + D: KVStore, { let mut child_counts = HashMap::new(); @@ -143,7 +145,7 @@ where fn discover_child_counts_helper(db: &mut D, game: &G, state: State, child_counts: &mut HashMap) where G: DTransition + Bounded + ClassicPuzzle, - D: KVStore, + D: KVStore, { child_counts.insert(state, game.prograde(state).len()); @@ -153,6 +155,7 @@ where } } } + /* DATABASE INITIALIZATION */ /// Initializes a volatile database, creating a table schema according to the @@ -160,7 +163,7 @@ where /// to that table before returning the database handle. fn volatile_database(game: &G) -> Result where - G: Extensive, + G: Extensive + Game, { let id = game.id(); let db = volatile::Database::initialize(); @@ -174,6 +177,9 @@ where .context("Failed to select solution set database table.")?; Ok(db) + + // This is only for testing purposes + } @@ -184,7 +190,7 @@ mod tests { use crate::model::{State, Turn}; use std::collections::{HashMap, VecDeque}; use crate::interface::{IOMode, SolutionMode}; - use crate::solver::SimpleUtility; + use crate::model::SimpleUtility; use super::{discover_child_counts, volatile_database}; @@ -204,17 +210,13 @@ mod tests { } impl Game for PuzzleGraph { - fn initialize(variant: Option) -> Result + fn new(variant: Option) -> Result where Self: Sized { unimplemented!(); } - fn forward(&mut self, history: Vec) -> Result<()> { - unimplemented!(); - } - fn id(&self) -> String { String::from("GameGraph") } @@ -244,14 +246,12 @@ mod tests { } } - impl SimpleSum<1> for PuzzleGraph { - fn utility(&self, state: State) -> [SimpleUtility; 1] { - [self.adj_list[state as usize].utility.unwrap()] + impl ClassicPuzzle for PuzzleGraph { + fn utility(&self, state: State) -> SimpleUtility { + self.adj_list[state as usize].utility.unwrap() } } - impl ClassicPuzzle for PuzzleGraph {} - impl DTransition for PuzzleGraph { fn prograde(&self, state: State) -> Vec { self.adj_list[state as usize].children.clone() From b9d31bafb610d0e9da30cc2e44239f6c6d9566a3 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Fri, 12 Apr 2024 04:27:16 -0700 Subject: [PATCH 04/16] Added blanket implementations for utility traits --- src/game/mod.rs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index 6a89a4a..97be225 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -350,16 +350,29 @@ where G: SimpleSum, { fn utility(&self, state: State) -> [Utility; N] { - todo!() + SimpleSum::utility(self, state).map(|x| x as Utility) } } impl SimpleSum<2> for G where - G: ClassicGame, + G: ClassicGame + Extensive<2>, { fn utility(&self, state: State) -> [SimpleUtility; 2] { - todo!() + let player_utility = ClassicGame::utility(self, state); + let other_player_utility = match player_utility { + SimpleUtility::WIN => SimpleUtility::LOSE, + SimpleUtility::LOSE => SimpleUtility::WIN, + SimpleUtility::TIE => SimpleUtility::TIE, + SimpleUtility::DRAW => SimpleUtility::DRAW, + }; + + if Extensive::turn(self, state) == 0 { + [player_utility, other_player_utility] + } + else { + [other_player_utility, player_utility] + } } } @@ -368,6 +381,6 @@ where G: ClassicPuzzle, { fn utility(&self, state: State) -> [SimpleUtility; 1] { - todo!() + [ClassicPuzzle::utility(self, state)] } } From 621d02bbdb1f52b04aa09407d2cff183d65d4697 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:27:51 +0000 Subject: [PATCH 05/16] Format Rust code using rustfmt --- src/game/mod.rs | 9 +- src/solver/algorithm/strong/cyclic.rs | 163 +++++++++++++++++--------- src/solver/algorithm/strong/puzzle.rs | 106 +++++++++++------ 3 files changed, 184 insertions(+), 94 deletions(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index 97be225..9aeaa89 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -350,7 +350,7 @@ where G: SimpleSum, { fn utility(&self, state: State) -> [Utility; N] { - SimpleSum::utility(self, state).map(|x| x as Utility) + SimpleSum::utility(self, state).map(|x| x as Utility) } } @@ -368,10 +368,9 @@ where }; if Extensive::turn(self, state) == 0 { - [player_utility, other_player_utility] - } - else { - [other_player_utility, player_utility] + [player_utility, other_player_utility] + } else { + [other_player_utility, player_utility] } } } diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index afacaa4..0b2a07f 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -8,14 +8,14 @@ //! - Max Fierro, 12/3/2023 (maxfierro@berkeley.edu) //! - Ishir Garg, 3/12/2024 (ishirgarg@berkeley.edu) -use std::collections::{VecDeque, HashMap}; use anyhow::{Context, Result}; +use std::collections::{HashMap, VecDeque}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; -use crate::game::{Bounded, DTransition, Extensive, SimpleSum, Game}; -use crate::model::SimpleUtility; +use crate::game::{Bounded, DTransition, Extensive, Game, SimpleSum}; use crate::interface::IOMode; +use crate::model::SimpleUtility; use crate::model::{PlayerCount, Remoteness, State, Turn}; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; @@ -31,18 +31,20 @@ const BUFFER_SIZE: usize = 128; /// The exact number of bits that are used to encode utility for one player. const UTILITY_SIZE: usize = 2; - -pub fn two_player_zero_sum_dynamic_solver(game: &G, mode: IOMode) -> Result<()> +pub fn two_player_zero_sum_dynamic_solver( + game: &G, + mode: IOMode, +) -> Result<()> where - G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game + G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, { - let mut db = volatile_database(game) - .context("Failed to initialize database.")?; + let mut db = + volatile_database(game).context("Failed to initialize database.")?; basic_loopy_solver(game, &mut db)?; Ok(()) } -fn basic_loopy_solver (game: &G, db: &mut D) -> Result<()> +fn basic_loopy_solver(game: &G, db: &mut D) -> Result<()> where G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, D: KVStore, @@ -53,26 +55,39 @@ where let mut child_counts = HashMap::new(); - enqueue_children(&mut winning_frontier, &mut tying_frontier, &mut losing_frontier, game.start(), game, &mut child_counts, db)?; + enqueue_children( + &mut winning_frontier, + &mut tying_frontier, + &mut losing_frontier, + game.start(), + game, + &mut child_counts, + db, + )?; // Process winning and losing frontiers - while !winning_frontier.is_empty() && !losing_frontier.is_empty() && !tying_frontier.is_empty() { + while !winning_frontier.is_empty() + && !losing_frontier.is_empty() + && !tying_frontier.is_empty() + { let child = if !losing_frontier.is_empty() { - losing_frontier.pop_front().unwrap() - } else if !winning_frontier.is_empty() { - winning_frontier.pop_front().unwrap() - } else { - tying_frontier.pop_front().unwrap() - }; + losing_frontier + .pop_front() + .unwrap() + } else if !winning_frontier.is_empty() { + winning_frontier + .pop_front() + .unwrap() + } else { + tying_frontier.pop_front().unwrap() + }; let db_entry = RecordBuffer::from(db.get(child).unwrap()) .context("Failed to create record for middle state.")?; let child_utility = db_entry .get_utility(game.turn(child)) .context("Failed to get utility from record.")?; - let child_remoteness = db_entry - .get_remoteness(); - + let child_remoteness = db_entry.get_remoteness(); let parents = game.retrograde(child); // If child is a losing position @@ -97,7 +112,9 @@ where // If child is a winning position else if matches!(child_utility, SimpleUtility::WIN) { for parent in parents { - let child_count = *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueuing stage"); + let child_count = *child_counts.get(&parent).expect( + "Failed to enqueue parent state in initial enqueuing stage", + ); // Parent has already been solved if child_count == 0 { continue; @@ -121,7 +138,9 @@ where // Child should never be a tying position else if matches!(child_utility, SimpleUtility::TIE) { for parent in parents { - let child_count = *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueuing stage"); + let child_count = *child_counts.get(&parent).expect( + "Failed to enqueue parent state in initial enqueuing stage", + ); // Parent has already been solved if child_count == 0 { continue; @@ -138,9 +157,7 @@ where // Update child count child_counts.insert(parent, 0); } - - } - else { + } else { panic!["Position with invalid utility found in frontiers"]; } } @@ -159,15 +176,16 @@ where } /// Set up the initial frontiers and primitive position database entries -fn enqueue_children(winning_frontier: &mut VecDeque, - tying_frontier: &mut VecDeque, - losing_frontier: &mut VecDeque, - curr_state: State, - game: &G, - child_counts: &mut HashMap, - db: &mut D +fn enqueue_children( + winning_frontier: &mut VecDeque, + tying_frontier: &mut VecDeque, + losing_frontier: &mut VecDeque, + curr_state: State, + game: &G, + child_counts: &mut HashMap, + db: &mut D, ) -> Result<()> -where +where G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, D: KVStore, { @@ -180,11 +198,16 @@ where .context("Failed to set remoteness for end state.")?; db.put(curr_state, &buf); - match game.utility(curr_state).get(game.turn(curr_state)) { + match game + .utility(curr_state) + .get(game.turn(curr_state)) + { Some(SimpleUtility::WIN) => winning_frontier.push_back(curr_state), Some(SimpleUtility::TIE) => tying_frontier.push_back(curr_state), - Some(SimpleUtility::LOSE) => losing_frontier.push_back(curr_state), - _ => panic!["Utility for primitive ending position found to be draw"] + Some(SimpleUtility::LOSE) => losing_frontier.push_back(curr_state), + _ => { + panic!["Utility for primitive ending position found to be draw"] + }, } return Ok(()); } @@ -197,14 +220,20 @@ where if child_counts.contains_key(&child) { continue; } - enqueue_children(winning_frontier, tying_frontier, losing_frontier, child, game, child_counts, db)?; + enqueue_children( + winning_frontier, + tying_frontier, + losing_frontier, + child, + game, + child_counts, + db, + )?; } Ok(()) } - - /* DATABASE INITIALIZATION */ /// Initializes a volatile database, creating a table schema according to the @@ -230,19 +259,21 @@ where #[cfg(test)] mod tests { - use anyhow::Result; - use crate::game::{Game, GameData, Bounded, DTransition, Extensive, SimpleSum}; - use crate::model::{State, Turn}; - use std::collections::{HashMap, VecDeque}; + use crate::game::{ + Bounded, DTransition, Extensive, Game, GameData, SimpleSum, + }; use crate::interface::{IOMode, SolutionMode}; use crate::model::SimpleUtility; + use crate::model::{State, Turn}; + use anyhow::Result; + use std::collections::{HashMap, VecDeque}; use super::{enqueue_children, volatile_database}; struct GameNode { turn: Turn, utility: Vec, - children: Vec + children: Vec, } struct GameGraph { @@ -253,10 +284,10 @@ mod tests { impl Game for GameGraph { fn new(variant: Option) -> Result where - Self: Sized + Self: Sized, { unimplemented!(); - } + } fn id(&self) -> String { String::from("GameGraph") @@ -273,11 +304,13 @@ mod tests { impl Bounded for GameGraph { fn start(&self) -> u64 { - 0 + 0 } - + fn end(&self, state: State) -> bool { - self.adj_list[state as usize].children.is_empty() + self.adj_list[state as usize] + .children + .is_empty() } } @@ -296,22 +329,32 @@ mod tests { impl DTransition for GameGraph { fn prograde(&self, state: State) -> Vec { - self.adj_list[state as usize].children.clone() + self.adj_list[state as usize] + .children + .clone() } fn retrograde(&self, state: State) -> Vec { - todo![]; + todo![]; } } #[test] - fn enqueues_children_properly() -> Result<()>{ + fn enqueues_children_properly() -> Result<()> { let graph = GameGraph { num_states: 2, adj_list: vec![ - GameNode {turn: 0, utility: vec![], children: vec![1]}, - GameNode {turn: 1, utility: vec![SimpleUtility::LOSE, SimpleUtility::WIN], children: vec![]} - ] + GameNode { + turn: 0, + utility: vec![], + children: vec![1], + }, + GameNode { + turn: 1, + utility: vec![SimpleUtility::LOSE, SimpleUtility::WIN], + children: vec![], + }, + ], }; let mut db = volatile_database(&graph)?; @@ -321,7 +364,15 @@ mod tests { let mut losing_frontier = VecDeque::new(); let mut child_counts = HashMap::new(); - enqueue_children(&mut winning_frontier, &mut tying_frontier, &mut losing_frontier, graph.start(), &graph, &mut child_counts, &mut db)?; + enqueue_children( + &mut winning_frontier, + &mut tying_frontier, + &mut losing_frontier, + graph.start(), + &graph, + &mut child_counts, + &mut db, + )?; assert!(winning_frontier.is_empty()); assert!(tying_frontier.is_empty()); diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 96bef9f..c6ad1fb 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -7,20 +7,24 @@ use anyhow::{Context, Result}; -use std::collections::{HashSet, VecDeque, HashMap}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; -use crate::game::{Game, Bounded, DTransition, ClassicPuzzle, Extensive}; +use crate::game::{Bounded, ClassicPuzzle, DTransition, Extensive, Game}; use crate::interface::IOMode; +use crate::model::SimpleUtility; use crate::model::{Remoteness, State}; +use crate::solver::error::SolverError::SolverViolation; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; -use crate::model::SimpleUtility; -use crate::solver::error::SolverError::SolverViolation; +use std::collections::{HashMap, HashSet, VecDeque}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where - G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + G: DTransition + + Bounded + + ClassicPuzzle + + Extensive<1> + + Game, { let mut db = volatile_database(game) .context("Failed to initialize volatile database.")?; @@ -31,15 +35,21 @@ where Ok(()) } - fn reverse_bfs(db: &mut D, game: &G) -> Result<()> where - G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + G: DTransition + + Bounded + + ClassicPuzzle + + Extensive<1> + + Game, D: KVStore, { // Get end states and create frontiers let mut child_counts = discover_child_counts(db, game); - let end_states = child_counts.iter().filter(|&x| *x.1 == 0).map(|x| *x.0); + let end_states = child_counts + .iter() + .filter(|&x| *x.1 == 0) + .map(|x| *x.0); let mut winning_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); let mut losing_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); @@ -63,7 +73,6 @@ where // Perform BFS on winning states while let Some((state, remoteness)) = winning_queue.pop_front() { - let mut buf = RecordBuffer::new(1) .context("Failed to create placeholder record.")?; buf.set_utility([SimpleUtility::WIN]) @@ -107,10 +116,14 @@ where // discover state 2 for the first time. match child_counts.get(&parent) { Some(count) => child_counts.insert(parent, count - 1), - None => child_counts.insert(parent, game.prograde(parent).len() - 1), + None => { + child_counts.insert(parent, game.prograde(parent).len() - 1) + }, }; - if !visited.contains(&parent) && *child_counts.get(&state).unwrap() == 0 { + if !visited.contains(&parent) + && *child_counts.get(&state).unwrap() == 0 + { losing_queue.push_back((parent, remoteness + 1)); } } @@ -132,7 +145,11 @@ where fn discover_child_counts(db: &mut D, game: &G) -> HashMap where - G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + G: DTransition + + Bounded + + ClassicPuzzle + + Extensive<1> + + Game, D: KVStore, { let mut child_counts = HashMap::new(); @@ -142,8 +159,12 @@ where child_counts } -fn discover_child_counts_helper(db: &mut D, game: &G, state: State, child_counts: &mut HashMap) -where +fn discover_child_counts_helper( + db: &mut D, + game: &G, + state: State, + child_counts: &mut HashMap, +) where G: DTransition + Bounded + ClassicPuzzle, D: KVStore, { @@ -179,24 +200,25 @@ where Ok(db) // This is only for testing purposes - } - #[cfg(test)] mod tests { - use anyhow::Result; - use crate::game::{Game, GameData, Bounded, DTransition, Extensive, ClassicPuzzle, SimpleSum}; - use crate::model::{State, Turn}; - use std::collections::{HashMap, VecDeque}; + use crate::game::{ + Bounded, ClassicPuzzle, DTransition, Extensive, Game, GameData, + SimpleSum, + }; use crate::interface::{IOMode, SolutionMode}; use crate::model::SimpleUtility; + use crate::model::{State, Turn}; + use anyhow::Result; + use std::collections::{HashMap, VecDeque}; use super::{discover_child_counts, volatile_database}; struct GameNode { utility: Option, // Is None for non-primitive puzzle nodes - children: Vec + children: Vec, } struct PuzzleGraph { @@ -212,10 +234,10 @@ mod tests { impl Game for PuzzleGraph { fn new(variant: Option) -> Result where - Self: Sized + Self: Sized, { unimplemented!(); - } + } fn id(&self) -> String { String::from("GameGraph") @@ -232,11 +254,13 @@ mod tests { impl Bounded for PuzzleGraph { fn start(&self) -> u64 { - 0 + 0 } - + fn end(&self, state: State) -> bool { - self.adj_list[state as usize].children.is_empty() + self.adj_list[state as usize] + .children + .is_empty() } } @@ -248,27 +272,43 @@ mod tests { impl ClassicPuzzle for PuzzleGraph { fn utility(&self, state: State) -> SimpleUtility { - self.adj_list[state as usize].utility.unwrap() + self.adj_list[state as usize] + .utility + .unwrap() } } impl DTransition for PuzzleGraph { fn prograde(&self, state: State) -> Vec { - self.adj_list[state as usize].children.clone() + self.adj_list[state as usize] + .children + .clone() } fn retrograde(&self, state: State) -> Vec { - (0..self.size()).filter(|&s| self.adj_list[s as usize].children.contains(&state)).collect() + (0..self.size()) + .filter(|&s| { + self.adj_list[s as usize] + .children + .contains(&state) + }) + .collect() } } #[test] - fn gets_child_counts_correctly() -> Result<()>{ + fn gets_child_counts_correctly() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode {utility: None, children: vec![1]}, - GameNode {utility: Some(SimpleUtility::LOSE), children: vec![]} - ] + GameNode { + utility: None, + children: vec![1], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![], + }, + ], }; let mut db = volatile_database(&graph)?; From 59c0da0a92b32ac6b85620e77898785c7af8562e Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Thu, 18 Apr 2024 00:57:33 -0700 Subject: [PATCH 06/16] added surcc record --- src/solver/algorithm/record/surcc.rs | 538 ++++++++++++++++++++++++++ src/solver/algorithm/strong/puzzle.rs | 327 +++++++++++++--- src/solver/mod.rs | 7 + 3 files changed, 827 insertions(+), 45 deletions(-) create mode 100644 src/solver/algorithm/record/surcc.rs diff --git a/src/solver/algorithm/record/surcc.rs b/src/solver/algorithm/record/surcc.rs new file mode 100644 index 0000000..580b911 --- /dev/null +++ b/src/solver/algorithm/record/surcc.rs @@ -0,0 +1,538 @@ + +//! # Simple-Utility Remoteness with Child Counts (SURCC) Record Module +//! +//! Implementation of a database record buffer for storing simple utilities +//! information of different amounts of players and the remoteness value +//! associated with a particular game state, along with the child count This is +//! mainly for internal solver use; some solving algorithms need to track child +//! counts along with each state +//! +//! #### Authorship +//! +//! - Ishir Garg, 4/1/2024 (ishirgarg@berkeley.edu) + +use anyhow::{Context, Result}; +use bitvec::field::BitField; +use bitvec::order::Msb0; +use bitvec::slice::BitSlice; +use bitvec::{bitarr, BitArr}; + +use crate::database::{Attribute, Datatype, Record, Schema, SchemaBuilder}; +use crate::model::{PlayerCount, Remoteness, SimpleUtility, Turn}; +use crate::solver::error::SolverError::RecordViolation; +use crate::solver::util; +use crate::solver::RecordType; + +/* CONSTANTS */ + +/// The exact number of bits that are used to encode remoteness. +pub const REMOTENESS_SIZE: usize = 16; + +/// The maximum number of bits that can be used to encode a single record. +pub const BUFFER_SIZE: usize = 128; + +/// The exact number of bits that are used to encode utility for one player. +pub const UTILITY_SIZE: usize = 2; + +/// The exact number of bits that are used to encode the child counts for a +/// state/record entry. +pub const CHILD_COUNT_SIZE: usize = 32; + +/// Type for child count +type ChildCount = u64; + +/* SCHEMA GENERATOR */ + +/// Return the database table schema associated with a record instance with +/// a specific number of `players` under this record implementation. +pub fn schema(players: PlayerCount) -> Result { + if RecordBuffer::bit_size(players) > BUFFER_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(players), + hint: format!( + "This record can only hold utility values for up to {} \ + players, but there was an attempt to create a schema that \ + would represent one holding {} players.", + RecordBuffer::player_count(BUFFER_SIZE), + players + ), + })? + } else { + let mut schema = SchemaBuilder::new().of(RecordType::SUR(players)); + + for i in 0..players { + let name = &format!("P{} utility", i); + let data = Datatype::ENUM; + let size = UTILITY_SIZE; + schema = schema + .add(Attribute::new(name, data, size)) + .context( + "Failed to add utility attribute to database schema.", + )?; + } + + let name = "State remoteness"; + let data = Datatype::UINT; + let size = REMOTENESS_SIZE; + schema = schema + .add(Attribute::new(name, data, size)) + .context( + "Failed to add remoteness attribute to database schema.", + )?; + + let name = "State child count"; + let data = Datatype::UINT; + let size = CHILD_COUNT_SIZE; + schema = schema + .add(Attribute::new(name, data, size)) + .context( + "Failed to add child count attribute to database schema.", + )?; + + Ok(schema.build()) + } +} + +/* RECORD IMPLEMENTATION */ + +/// Solver-specific record entry, meant to communicate the remoteness and each +/// player's utility at a corresponding game state. The layout is as follows: +/// +/// ```none +/// [UTILITY_SIZE bits: P0 utility] +/// ... +/// [UTILITY_SIZE bits: P(N-1) utility] +/// [REMOTENESS_SIZE bits: Remoteness] +/// [CHILD_COUNT_SIZE bits: Child count] +/// [0b0 until BUFFER_SIZE] +/// ``` +/// +/// The number of players `N` is limited by `BUFFER_SIZE`, because a statically +/// sized buffer is used for intermediary storage. The utility and remoteness +/// values are encoded in big-endian, with utility being a signed two's +/// complement integer and remoteness an unsigned integer. +pub struct RecordBuffer { + buf: BitArr!(for BUFFER_SIZE, in u8, Msb0), + players: PlayerCount, +} + +impl Record for RecordBuffer { + #[inline(always)] + fn raw(&self) -> &BitSlice { + &self.buf[..Self::bit_size(self.players)] + } +} + +impl RecordBuffer { + // Returns the string name for this record buffer + fn into_string(players: PlayerCount) -> String { + format!("Simple Utility Remoteness Child Count ({} players)", players) + } + + /// Returns a new instance of a bit-packed record buffer that is able to + /// store utility values for `players`. Fails if `players` is too high for + /// the underlying buffer's capacity. + #[inline(always)] + pub fn new(players: PlayerCount) -> Result { + if Self::bit_size(players) > BUFFER_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(players), + hint: format!( + "The record can only hold utility values for up to {} \ + players, but there was an attempt to instantiate one for \ + {} players.", + Self::player_count(BUFFER_SIZE), + players + ), + })? + } else { + Ok(Self { + buf: bitarr!(u8, Msb0; 0; BUFFER_SIZE), + players, + }) + } + } + + /// Return a new instance with `bits` as the underlying buffer. Fails in the + /// event that the size of `bits` is incoherent with the record. + #[inline(always)] + pub fn from(bits: &BitSlice) -> Result { + let len = bits.len(); + if len > BUFFER_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(0), + hint: format!( + "The record implementation operates on a buffer of {} \ + bits, but there was an attempt to instantiate one from a \ + buffer of {} bits.", + BUFFER_SIZE, len, + ), + })? + } else if len < Self::minimum_bit_size() { + Err(RecordViolation { + name: RecordBuffer::into_string(0), + hint: format!( + "This record implementation stores utility values, but \ + there was an attempt to instantiate one with from a buffer \ + with {} bits, which is not enough to store a remoteness \ + value (which takes {} bits).", + len, REMOTENESS_SIZE, + ), + })? + } else { + let players = Self::player_count(len); + let mut buf = bitarr!(u8, Msb0; 0; BUFFER_SIZE); + buf[..len].copy_from_bitslice(bits); + Ok(Self { players, buf }) + } + } + + /* GET METHODS */ + + /// Parse and return the utility value corresponding to `player`. Fails if + /// the `player` index passed in is incoherent with player count. + #[inline(always)] + pub fn get_utility(&self, player: Turn) -> Result { + if player >= self.players { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "A record was instantiated with {} utility entries, and \ + there was an attempt to fetch the utility of player {} \ + (0-indexed) from that record instance.", + self.players, player, + ), + })? + } else { + let start = Self::utility_index(player); + let end = start + UTILITY_SIZE; + let val = self.buf[start..end].load_be::(); + if let Ok(utility) = SimpleUtility::try_from(val) { + Ok(utility) + } else { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "There was an attempt to deserialize a utility value \ + of '{}' into a simple utility type.", + val, + ), + })? + } + } + } + + /// Parse and return the remoteness value in the record encoding. Failure + /// here indicates corrupted state. + #[inline(always)] + pub fn get_remoteness(&self) -> Remoteness { + let start = Self::remoteness_index(self.players); + let end = start + REMOTENESS_SIZE; + self.buf[start..end].load_be::() + } + + /// Parse and return the child count value in the record encoding. Failure + /// here indicates corrupted state. + #[inline(always)] + pub fn get_child_count(&self) -> ChildCount { + let start = Self::child_count_index(self.players); + let end = start + CHILD_COUNT_SIZE; + self.buf[start..end].load_be::() + } + + /* SET METHODS */ + + /// Set this entry to have the utility values in `v` for each player. Fails + /// if any of the utility values are too high to fit in the space dedicated + /// for each player's utility, or if there is a mismatch between player + /// count and the number of utility values passed in. + #[inline(always)] + pub fn set_utility( + &mut self, + v: [SimpleUtility; N], + ) -> Result<()> { + if N != self.players { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "A record was instantiated with {} utility entries, and \ + there was an attempt to use a {}-entry utility list to \ + update the record utility values.", + self.players, N, + ), + })? + } else { + for player in 0..self.players { + let utility = v[player] as u64; + let size = util::min_ubits(utility); + if size > UTILITY_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "This record implementation uses {} bits to store \ + signed integers representing utility values, but \ + there was an attempt to store a utility of {}, \ + which requires at least {} bits to store.", + UTILITY_SIZE, utility, size, + ), + })? + } + + let start = Self::utility_index(player); + let end = start + UTILITY_SIZE; + self.buf[start..end].store_be(utility); + } + Ok(()) + } + } + + /// Set this entry to have `value` remoteness. Fails if `value` is too high + /// to fit in the space dedicated for remoteness within the record. + #[inline(always)] + pub fn set_remoteness(&mut self, value: Remoteness) -> Result<()> { + let size = util::min_ubits(value); + if size > REMOTENESS_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "This record implementation uses {} bits to store unsigned \ + integers representing remoteness values, but there was an \ + attempt to store a remoteness value of {}, which requires \ + at least {} bits to store.", + REMOTENESS_SIZE, value, size, + ), + })? + } else { + let start = Self::remoteness_index(self.players); + let end = start + REMOTENESS_SIZE; + self.buf[start..end].store_be(value); + Ok(()) + } + } + + /// Set this entry to have `value` child count. Fails if `value` is too high + /// to fit in the space dedicated for child count within the record. + #[inline(always)] + pub fn set_child_count(&mut self, value: ChildCount) -> Result<()> { + let size = util::min_ubits(value); + if size > CHILD_COUNT_SIZE { + Err(RecordViolation { + name: RecordBuffer::into_string(self.players), + hint: format!( + "This record implementation uses {} bits to store unsigned \ + integers representing child count values, but there was an \ + attempt to store a child count value of {}, which requires \ + at least {} bits to store.", + CHILD_COUNT_SIZE, value, size, + ), + })? + } else { + let start = Self::child_count_index(self.players); + let end = start + CHILD_COUNT_SIZE; + self.buf[start..end].store_be(value); + Ok(()) + } + } + + /* LAYOUT HELPER METHODS */ + + /// Return the number of bits that would be needed to store a record + /// containing utility information for `players` as well as remoteness. + #[inline(always)] + const fn bit_size(players: usize) -> usize { + (players * UTILITY_SIZE) + REMOTENESS_SIZE + } + + /// Return the minimum number of bits needed for a valid record buffer. + #[inline(always)] + const fn minimum_bit_size() -> usize { + REMOTENESS_SIZE + } + + /// Return the bit index of the remoteness entry start in the record buffer. + #[inline(always)] + const fn remoteness_index(players: usize) -> usize { + players * UTILITY_SIZE + } + + /// Return the bit index of the 'i'th player's utility entry start. + #[inline(always)] + const fn utility_index(player: Turn) -> usize { + player * UTILITY_SIZE + } + + /// Return the bit index of the child count entry + #[inline(always)] + const fn child_count_index(players: usize) -> usize { + players * UTILITY_SIZE + REMOTENESS_SIZE + } + + /// Return the maximum number of utility entries supported by a dense record + /// (one that maximizes bit usage) with `length`. Ignores unused bits. + #[inline(always)] + const fn player_count(length: usize) -> usize { + (length - REMOTENESS_SIZE) / UTILITY_SIZE + } +} + +#[cfg(test)] +mod tests { + + use super::*; + // The maximum and minimum numeric values that can be represented with + // exactly UTILITY_SIZE bits in two's complement. + // + // Example if UTILITY_SIZE is 8: + // + // * `MAX_UTILITY = 0b01111111 = 127 = 2^(8 - 1) - 1` + // * `MIN_UTILITY = 0b10000000 = -128 = -127 - 1` + // + // Useful: https://www.omnicalculator.com/math/twos-complement + const MAX_UTILITY: SimpleUtility = SimpleUtility::TIE; + const MIN_UTILITY: SimpleUtility = SimpleUtility::WIN; + + // The maximum numeric remoteness value that can be expressed with exactly + // REMOTENESS_SIZE bits in an unsigned integer. + const MAX_REMOTENESS: Remoteness = 2_u64.pow(REMOTENESS_SIZE as u32) - 1; + const MAX_CHILD_COUNT: ChildCount = 2_u64.pow(CHILD_COUNT_SIZE as u32) - 1; + + #[test] + fn initialize_with_valid_player_count() { + for i in 0..=RecordBuffer::player_count(BUFFER_SIZE) { + assert!(RecordBuffer::new(i).is_ok()) + } + } + + #[test] + fn initialize_with_invalid_player_count() { + let max = RecordBuffer::player_count(BUFFER_SIZE); + + assert!(RecordBuffer::new(max + 1).is_err()); + assert!(RecordBuffer::new(max + 10).is_err()); + assert!(RecordBuffer::new(max + 100).is_err()); + } + + #[test] + fn initialize_from_valid_buffer() { + let buf = bitarr!(u8, Msb0; 0; BUFFER_SIZE); + for i in REMOTENESS_SIZE..BUFFER_SIZE { + assert!(RecordBuffer::from(&buf[0..i]).is_ok()); + } + } + + #[test] + fn initialize_from_invalid_buffer() { + let buf1 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 1); + let buf2 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 10); + let buf3 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 100); + + assert!(RecordBuffer::from(&buf1).is_err()); + assert!(RecordBuffer::from(&buf2).is_err()); + assert!(RecordBuffer::from(&buf3).is_err()); + } + + #[test] + fn set_record_attributes() { + let mut r1 = RecordBuffer::new(7).unwrap(); + let mut r2 = RecordBuffer::new(4).unwrap(); + let mut r3 = RecordBuffer::new(0).unwrap(); + + let v1 = [SimpleUtility::WIN; 7]; + let v2 = [SimpleUtility::TIE; 4]; + let v3: [SimpleUtility; 0] = []; + + let v4 = [MAX_UTILITY; 7]; + let v5 = [MIN_UTILITY; 4]; + let v6 = [SimpleUtility::DRAW]; + + let good = Remoteness::MIN; + let bad = Remoteness::MAX; + + assert!(r1.set_utility(v1).is_ok()); + assert!(r2.set_utility(v2).is_ok()); + assert!(r3.set_utility(v3).is_ok()); + + assert!(r1.set_utility(v4).is_ok()); + assert!(r2.set_utility(v5).is_ok()); + assert!(r3.set_utility(v6).is_err()); + + assert!(r1.set_remoteness(good).is_ok()); + assert!(r2.set_remoteness(good).is_ok()); + assert!(r3.set_remoteness(good).is_ok()); + + assert!(r1.set_remoteness(bad).is_err()); + assert!(r2.set_remoteness(bad).is_err()); + assert!(r3.set_remoteness(bad).is_err()); + } + + #[test] + fn data_is_valid_after_round_trip() { + let mut record = RecordBuffer::new(5).unwrap(); + let payoffs = [ + SimpleUtility::LOSE, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::LOSE, + SimpleUtility::LOSE, + ]; + let remoteness = 790; + + record + .set_utility(payoffs) + .unwrap(); + + record + .set_remoteness(remoteness) + .unwrap(); + + // Utilities unchanged after insert and fetch + for i in 0..5 { + let fetched_utility = record.get_utility(i).unwrap(); + let actual_utility = payoffs[i]; + assert!(matches!(fetched_utility, actual_utility)); + } + + // Remoteness unchanged after insert and fetch + let fetched_remoteness = record.get_remoteness(); + let actual_remoteness = remoteness; + assert_eq!(fetched_remoteness, actual_remoteness); + + // Fetching utility entries of invalid players + assert!(record.get_utility(5).is_err()); + assert!(record.get_utility(10).is_err()); + } + + #[test] + fn extreme_data_is_valid_after_round_trip() { + let mut record = RecordBuffer::new(6).unwrap(); + + let good = [ + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::TIE, + SimpleUtility::TIE, + SimpleUtility::DRAW, + SimpleUtility::WIN, + ]; + + let bad = [ + SimpleUtility::DRAW, + SimpleUtility::WIN, + SimpleUtility::TIE, + ]; + + assert!(record.set_utility(good).is_ok()); + assert!(record + .set_remoteness(MAX_REMOTENESS) + .is_ok()); + + for i in 0..6 { + let fetched_utility = record.get_utility(i).unwrap(); + let actual_utility = good[i]; + assert!(matches!(fetched_utility, actual_utility)); + } + + assert_eq!(record.get_remoteness(), MAX_REMOTENESS); + assert!(record.set_utility(bad).is_err()); + } +} diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index c6ad1fb..980d386 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -14,11 +14,11 @@ use crate::interface::IOMode; use crate::model::SimpleUtility; use crate::model::{Remoteness, State}; use crate::solver::error::SolverError::SolverViolation; -use crate::solver::record::sur::RecordBuffer; -use crate::solver::RecordType; use std::collections::{HashMap, HashSet, VecDeque}; +use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; +use crate::solver::algorithm::record::surcc::RecordBuffer; -pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> +pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where G: DTransition + Bounded @@ -29,13 +29,13 @@ where let mut db = volatile_database(game) .context("Failed to initialize volatile database.")?; - reverse_bfs(&mut db, game) + reverse_bfs_solver(&mut db, game) .context("Failed solving algorithm execution.")?; Ok(()) } -fn reverse_bfs(db: &mut D, game: &G) -> Result<()> +fn reverse_bfs_solver(db: &mut D, game: &G) -> Result<()> where G: DTransition + Bounded @@ -46,11 +46,16 @@ where { // Get end states and create frontiers let mut child_counts = discover_child_counts(db, game); + // Get all states with 0 child count (primitive states) let end_states = child_counts .iter() .filter(|&x| *x.1 == 0) .map(|x| *x.0); + // Contains states that have already been visited + let mut visited = HashSet::new(); + + // TODO: Change this to no longer store remoteness, just query db let mut winning_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); let mut losing_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); for end_state in end_states { @@ -66,11 +71,9 @@ where hint: format!("Primitive end position cannot have utility DRAW for a puzzle"), })?, } + visited.insert(end_state); } - // Contains states that have already been visited - let mut visited = HashSet::new(); - // Perform BFS on winning states while let Some((state, remoteness)) = winning_queue.pop_front() { let mut buf = RecordBuffer::new(1) @@ -81,18 +84,19 @@ where .context("Failed to set remoteness for state.")?; db.put(state, &buf); + // Zero out child counts so it doesn't get detected as draw child_counts.insert(state, 0); - visited.insert(state); - let parents = game.retrograde(state); - for parent in parents { + for parent in game.retrograde(state) { if !visited.contains(&parent) { winning_queue.push_back((parent, remoteness + 1)); + visited.insert(parent); } } } - // Perform BFS on losing states + // Perform BFS on losing states, where remoteness is the longest path to a losing primitive + // position. while let Some((state, remoteness)) = losing_queue.pop_front() { let mut buf = RecordBuffer::new(1) .context("Failed to create placeholder record.")?; @@ -102,29 +106,17 @@ where .context("Failed to set remoteness for state.")?; db.put(state, &buf); - visited.insert(state); let parents = game.retrograde(state); for parent in parents { - // The check below is needed, because it is theoretically possible - // for child_counts to NOT contain a position discovered by - // retrograde(). Consider a 3-node game tree with starting vertex 1, - // and edges (1 -> 2), (3 -> 2), where 2 is a losing primitive - // ending position. In this case, running discover_child_counts() on - // 1 above only gets child_counts for states 1 and 2, however - // calling retrograde on end state 2 in this BFS portion will - // discover state 2 for the first time. - match child_counts.get(&parent) { - Some(count) => child_counts.insert(parent, count - 1), - None => { - child_counts.insert(parent, game.prograde(parent).len() - 1) - }, - }; - - if !visited.contains(&parent) - && *child_counts.get(&state).unwrap() == 0 - { - losing_queue.push_back((parent, remoteness + 1)); + if !visited.contains(&parent) { + let new_child_count = *child_counts.get(&parent).unwrap() - 1; + child_counts.insert(parent, new_child_count); + + if new_child_count == 0 { + losing_queue.push_back((parent, remoteness + 1)); + visited.insert(parent); + } } } } @@ -170,7 +162,10 @@ fn discover_child_counts_helper( { child_counts.insert(state, game.prograde(state).len()); - for child in game.prograde(state) { + // We need to check both prograde and retrograde; consider a game with 3 nodes where 0-->2 + // and 1-->2. Then, starting from node 0 with only progrades would discover states 0 and 1; we + // need to include retrogrades to discover state 2. + for &child in game.prograde(state).iter().chain(game.retrograde(state).iter()) { if !child_counts.contains_key(&child) { discover_child_counts_helper(db, game, child, child_counts); } @@ -182,12 +177,15 @@ fn discover_child_counts_helper( /// Initializes a volatile database, creating a table schema according to the /// solver record layout, initializing a table with that schema, and switching /// to that table before returning the database handle. + +/* fn volatile_database(game: &G) -> Result where G: Extensive + Game, { let id = game.id(); let db = volatile::Database::initialize(); + let db = TestDB::initialize(); let schema = RecordType::SUR(1) .try_into() @@ -199,7 +197,47 @@ where Ok(db) - // This is only for testing purposes +} +*/ + +// THIS IS ONLY FOR TESTING PURPOSES +struct TestDB { + memory: HashMap> +} + +impl TestDB { + fn initialize() -> Self { + Self { + memory: HashMap::new() + } + } +} + +impl KVStore for TestDB { + fn put(&mut self, key: State, record: &R) { + let new = BitVec::from(record.raw()).clone(); + self.memory.insert(key, new); + } + + fn get(&self, key: State) -> Option<&bitvec::prelude::BitSlice> { + let vec_opt = self.memory.get(&key); + match vec_opt { + None => None, + Some(vect) => Some(&vect[..]), + } + } + + fn del(&mut self, key: State) { + unimplemented![]; + } +} + +fn volatile_database(game: &G) -> Result +where + G: Extensive + Game, +{ + let db = TestDB::initialize(); + Ok(db) } #[cfg(test)] @@ -213,12 +251,16 @@ mod tests { use crate::model::{State, Turn}; use anyhow::Result; use std::collections::{HashMap, VecDeque}; + use crate::solver::record::sur::RecordBuffer; + use crate::database::{KVStore, Tabular}; + use crate::game::mock; + use crate::node; - use super::{discover_child_counts, volatile_database}; + use super::{discover_child_counts, volatile_database, reverse_bfs_solver, TestDB}; struct GameNode { - utility: Option, // Is None for non-primitive puzzle nodes children: Vec, + utility: Option, // Is None for non-primitive puzzle nodes } struct PuzzleGraph { @@ -297,26 +339,221 @@ mod tests { } #[test] - fn gets_child_counts_correctly() -> Result<()> { + fn game_with_single_node_win() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { - utility: None, - children: vec![1], - }, - GameNode { - utility: Some(SimpleUtility::LOSE), - children: vec![], - }, + GameNode { children: vec![], utility: Some(SimpleUtility::WIN) } ], }; + // Check child counts let mut db = volatile_database(&graph)?; + let child_counts = discover_child_counts(&mut db, &graph); + + assert_eq!(child_counts, HashMap::from([(0, 0)])); + + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 0); + + Ok(()) + } + + #[test] + fn game_with_two_nodes_win() -> Result<()> { + let graph = PuzzleGraph { + adj_list: vec![ + GameNode { children: vec![1], utility: None }, + GameNode { children: vec![], utility: Some(SimpleUtility::WIN) }, + ], + }; + // Check child counts + let mut db = volatile_database(&graph)?; let child_counts = discover_child_counts(&mut db, &graph); assert_eq!(child_counts, HashMap::from([(0, 1), (1, 0)])); + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + matches!(RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + + assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 0); + + Ok(()) + } + + #[test] + fn game_with_dag_win() -> Result<()> { + let graph = PuzzleGraph { + adj_list: vec![ + GameNode { children: vec![1, 2, 4], utility: None }, + GameNode { children: vec![3], utility: None }, + GameNode { children: vec![3, 4], utility: None }, + GameNode { children: vec![4], utility: None }, + GameNode { children: vec![], utility: Some(SimpleUtility::WIN) }, + ], + }; + + // Check child counts + let mut db = volatile_database(&graph)?; + let child_counts = discover_child_counts(&mut db, &graph); + + assert_eq!(child_counts, HashMap::from([(0, 3), (1, 1), (2, 2), (3, 1), (4, 0)])); + + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + for i in 0..5 { + matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + } + + assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 2); + assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 0); + + Ok(()) + } + + #[test] + fn game_with_cyclic_graph_draw() -> Result<()> { + let graph = PuzzleGraph { + adj_list: vec![ + GameNode { children: vec![1, 2, 4], utility: None }, + GameNode { children: vec![3], utility: None }, + GameNode { children: vec![3, 4], utility: None }, + GameNode { children: vec![4], utility: None }, + GameNode { children: vec![5], utility: None }, + GameNode { children: vec![2, 4], utility: None }, + ], + }; + + + // Check child counts + let mut db = volatile_database(&graph)?; + let child_counts = discover_child_counts(&mut db, &graph); + + assert_eq!(child_counts, HashMap::from([(0, 3), (1, 1), (2, 2), (3, 1), (4, 1), (5, 2)])); + + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + for i in 0..5 { + matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::TIE); + } + + Ok(()) + } + + #[test] + fn game_with_dag_win_and_lose() -> Result<()> { + let graph = PuzzleGraph { + adj_list: vec![ + GameNode { utility: None, children: vec![3] }, + GameNode { utility: None, children: vec![4] }, + GameNode { utility: None, children: vec![4] }, + GameNode { utility: None, children: vec![4, 5] }, + GameNode { utility: None, children: vec![8, 0] }, + GameNode { utility: Some(SimpleUtility::WIN), children: vec![] }, + GameNode { utility: None, children: vec![8] }, + GameNode { utility: None, children: vec![6, 8] }, + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![] }, + ], + }; + + // Check child counts + let mut db = volatile_database(&graph)?; + let child_counts = discover_child_counts(&mut db, &graph); + + assert_eq!(child_counts, HashMap::from([(0, 1), (1, 1), (2, 1), (3, 2), (4, 2), (5, 0), (6, 1), (7, 2), (8, 0)])); + + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + for i in 0..=5 { + matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + } + matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + + assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); + assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 4); + assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 4); + assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 3); + assert_eq!(RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), 0); + assert_eq!(RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(7).unwrap())?.get_remoteness(), 2); + assert_eq!(RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), 0); + + Ok(()) + } + + #[test] + fn game_with_wld() -> Result<()> { + let graph = PuzzleGraph { + adj_list: vec![ + GameNode { utility: None, children: vec![3] }, + GameNode { utility: None, children: vec![4, 5] }, + GameNode { utility: None, children: vec![4] }, + GameNode { utility: None, children: vec![4, 5] }, + GameNode { utility: None, children: vec![8, 0] }, + GameNode { utility: Some(SimpleUtility::WIN), children: vec![] }, + + GameNode { utility: None, children: vec![8] }, + GameNode { utility: None, children: vec![6, 8, 13] }, + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![] }, + + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![10] }, + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![11] }, + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![9, 2] }, + + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![7] }, + GameNode { utility: Some(SimpleUtility::LOSE), children: vec![12] }, + ], + }; + + // Solve game + let mut db = volatile_database(&graph)?; + reverse_bfs_solver(&mut db, &graph); + + for i in 0..=5 { + matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + } + matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); + matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + for i in 9..=11 { + matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + } + matches!(RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); + matches!(RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); + + assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); + assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 4); + assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 3); + assert_eq!(RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), 0); + assert_eq!(RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), 1); + assert_eq!(RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), 0); + assert_eq!(RecordBuffer::from(db.get(9).unwrap())?.get_remoteness(), 7); + assert_eq!(RecordBuffer::from(db.get(10).unwrap())?.get_remoteness(), 6); + assert_eq!(RecordBuffer::from(db.get(11).unwrap())?.get_remoteness(), 5); + Ok(()) } } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 65d7651..c01108f 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -61,6 +61,13 @@ pub mod algorithm { pub mod acyclic; pub mod cyclic; } + + /// These are custom records for certain solving algorithms that may need to + /// store additional data; these should not be accessible outside the + /// solving algorithms + mod record { + pub mod surcc; + } } #[cfg(test)] From 589ed4a833b97e1270c620a481193a29a9fd7fab Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Thu, 18 Apr 2024 03:23:53 -0700 Subject: [PATCH 07/16] puzzle solver now uses surcc record --- src/solver/algorithm/strong/puzzle.rs | 208 ++++++++++----------- src/solver/mod.rs | 11 +- src/solver/{algorithm => }/record/surcc.rs | 49 ++--- src/solver/util.rs | 4 + 4 files changed, 130 insertions(+), 142 deletions(-) rename src/solver/{algorithm => }/record/surcc.rs (94%) diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 980d386..1c3f052 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -16,7 +16,7 @@ use crate::model::{Remoteness, State}; use crate::solver::error::SolverError::SolverViolation; use std::collections::{HashMap, HashSet, VecDeque}; use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; -use crate::solver::algorithm::record::surcc::RecordBuffer; +use crate::solver::record::surcc::{ChildCount, RecordBuffer}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where @@ -45,23 +45,18 @@ where D: KVStore, { // Get end states and create frontiers - let mut child_counts = discover_child_counts(db, game); - // Get all states with 0 child count (primitive states) - let end_states = child_counts - .iter() - .filter(|&x| *x.1 == 0) - .map(|x| *x.0); + let end_states = discover_child_counts(db, game)?; // Contains states that have already been visited let mut visited = HashSet::new(); // TODO: Change this to no longer store remoteness, just query db - let mut winning_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); - let mut losing_queue: VecDeque<(State, Remoteness)> = VecDeque::new(); + let mut winning_queue: VecDeque = VecDeque::new(); + let mut losing_queue: VecDeque = VecDeque::new(); for end_state in end_states { match ClassicPuzzle::utility(game, end_state) { - SimpleUtility::WIN => winning_queue.push_back((end_state, 0)), - SimpleUtility::LOSE => losing_queue.push_back((end_state, 0)), + SimpleUtility::WIN => winning_queue.push_back(end_state), + SimpleUtility::LOSE => losing_queue.push_back(end_state), SimpleUtility::TIE => Err(SolverViolation { name: "PuzzleSolver".to_string(), hint: format!("Primitive end position cannot have utility TIE for a puzzle"), @@ -72,104 +67,123 @@ where })?, } visited.insert(end_state); + + // Add ending state utility and remoteness to database + update_db_record(db, end_state, game.utility(end_state), 0, 0)?; } // Perform BFS on winning states - while let Some((state, remoteness)) = winning_queue.pop_front() { - let mut buf = RecordBuffer::new(1) - .context("Failed to create placeholder record.")?; - buf.set_utility([SimpleUtility::WIN]) - .context("Failed to set remoteness for state.")?; - buf.set_remoteness(remoteness) - .context("Failed to set remoteness for state.")?; - db.put(state, &buf); - - // Zero out child counts so it doesn't get detected as draw - child_counts.insert(state, 0); + while let Some(state) = winning_queue.pop_front() { + let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in game.retrograde(state) { if !visited.contains(&parent) { - winning_queue.push_back((parent, remoteness + 1)); + winning_queue.push_back(parent); visited.insert(parent); + update_db_record(db, parent, SimpleUtility::WIN, 1 + child_remoteness, 0)?; } } } // Perform BFS on losing states, where remoteness is the longest path to a losing primitive // position. - while let Some((state, remoteness)) = losing_queue.pop_front() { - let mut buf = RecordBuffer::new(1) - .context("Failed to create placeholder record.")?; - buf.set_utility([SimpleUtility::LOSE]) - .context("Failed to set remoteness for state.")?; - buf.set_remoteness(remoteness) - .context("Failed to set remoteness for state.")?; - db.put(state, &buf); - + while let Some(state) = losing_queue.pop_front() { let parents = game.retrograde(state); + let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in parents { if !visited.contains(&parent) { - let new_child_count = *child_counts.get(&parent).unwrap() - 1; - child_counts.insert(parent, new_child_count); - + // Get child count from database + let mut buf = RecordBuffer::from(db.get(parent).unwrap()) + .context("Failed to get record for middle state")?; + + let new_child_count = buf.get_child_count() - 1; + buf.set_child_count(new_child_count)?; + db.put(parent, &buf); + + // If all children have been solved, set this state as a losing state if new_child_count == 0 { - losing_queue.push_back((parent, remoteness + 1)); + losing_queue.push_back(parent); visited.insert(parent); + update_db_record(db, parent, SimpleUtility::LOSE, 1 + child_remoteness, 0)?; } } } } - // Get remaining draw positions - for (state, count) in child_counts { - if count > 0 { - let mut buf = RecordBuffer::new(1) - .context("Failed to create placeholder record.")?; - buf.set_utility([SimpleUtility::DRAW]) - .context("Failed to set remoteness for state.")?; - db.put(state, &buf); - } - } + Ok(()) +} + +/// Updates the database record for a puzzle with given simple utility and remoteness +fn update_db_record(db: &mut D, state: State, utility: SimpleUtility, remoteness: Remoteness, child_count: ChildCount) -> Result<()> +where + D: KVStore, +{ + let mut buf = RecordBuffer::from(db.get(state).unwrap()) + .context("Failed to create record for middle state")?; + buf.set_utility([utility]) + .context("Failed to set utility for state.")?; + buf.set_remoteness(remoteness) + .context("Failed to set remoteness for state.")?; + buf.set_child_count(child_count) + .context("Failed to set child count for state.")?; + db.put(state, &buf); Ok(()) } -fn discover_child_counts(db: &mut D, game: &G) -> HashMap +fn discover_child_counts( + db: &mut D, + game: &G, +) -> Result> where - G: DTransition - + Bounded - + ClassicPuzzle - + Extensive<1> - + Game, + G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, D: KVStore, { - let mut child_counts = HashMap::new(); + let mut end_states = Vec::new(); + discover_child_counts_helper(db, game, game.start(), &mut end_states)?; - discover_child_counts_helper(db, game, game.start(), &mut child_counts); - - child_counts + Ok(end_states) } +/// Adds child counts for each position to the database +/// Also returns a vector of all primitive positions fn discover_child_counts_helper( db: &mut D, game: &G, state: State, - child_counts: &mut HashMap, -) where - G: DTransition + Bounded + ClassicPuzzle, + end_states: &mut Vec +) -> Result<()> +where + G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, D: KVStore, { - child_counts.insert(state, game.prograde(state).len()); + let child_count = game.prograde(state).len() as ChildCount; + + if child_count == 0 { + end_states.push(state); + } + + // Initialize all utilies to draw; any utilities not set by the end must be + // a drawn position + let mut buf = RecordBuffer::new(1) + .context("Failed to create record for state")?; + buf.set_utility([SimpleUtility::DRAW]) + .context("Failed to set remoteness for state")?; + buf.set_child_count(child_count) + .context("Failed to set child count for state.")?; + db.put(state, &buf); // We need to check both prograde and retrograde; consider a game with 3 nodes where 0-->2 // and 1-->2. Then, starting from node 0 with only progrades would discover states 0 and 1; we // need to include retrogrades to discover state 2. for &child in game.prograde(state).iter().chain(game.retrograde(state).iter()) { - if !child_counts.contains_key(&child) { - discover_child_counts_helper(db, game, child, child_counts); + if db.get(child).is_none() { + discover_child_counts_helper(db, game, child, end_states)?; } } + + Ok(()) } /* DATABASE INITIALIZATION */ @@ -185,7 +199,6 @@ where { let id = game.id(); let db = volatile::Database::initialize(); - let db = TestDB::initialize(); let schema = RecordType::SUR(1) .try_into() @@ -251,7 +264,7 @@ mod tests { use crate::model::{State, Turn}; use anyhow::Result; use std::collections::{HashMap, VecDeque}; - use crate::solver::record::sur::RecordBuffer; + use crate::solver::record::surcc::RecordBuffer; use crate::database::{KVStore, Tabular}; use crate::game::mock; use crate::node; @@ -345,18 +358,12 @@ mod tests { GameNode { children: vec![], utility: Some(SimpleUtility::WIN) } ], }; - - // Check child counts - let mut db = volatile_database(&graph)?; - let child_counts = discover_child_counts(&mut db, &graph); - - assert_eq!(child_counts, HashMap::from([(0, 0)])); - + // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 0); Ok(()) @@ -371,18 +378,12 @@ mod tests { ], }; - // Check child counts - let mut db = volatile_database(&graph)?; - let child_counts = discover_child_counts(&mut db, &graph); - - assert_eq!(child_counts, HashMap::from([(0, 1), (1, 0)])); - // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN); - matches!(RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); + assert!(matches!(RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 0); @@ -402,18 +403,12 @@ mod tests { ], }; - // Check child counts - let mut db = volatile_database(&graph)?; - let child_counts = discover_child_counts(&mut db, &graph); - - assert_eq!(child_counts, HashMap::from([(0, 3), (1, 1), (2, 2), (3, 1), (4, 0)])); - // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); } assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); @@ -438,19 +433,12 @@ mod tests { ], }; - - // Check child counts - let mut db = volatile_database(&graph)?; - let child_counts = discover_child_counts(&mut db, &graph); - - assert_eq!(child_counts, HashMap::from([(0, 3), (1, 1), (2, 2), (3, 1), (4, 1), (5, 2)])); - // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::TIE); + assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); } Ok(()) @@ -472,22 +460,16 @@ mod tests { ], }; - // Check child counts - let mut db = volatile_database(&graph)?; - let child_counts = discover_child_counts(&mut db, &graph); - - assert_eq!(child_counts, HashMap::from([(0, 1), (1, 1), (2, 1), (3, 2), (4, 2), (5, 0), (6, 1), (7, 2), (8, 0)])); - - // Solve game + // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); for i in 0..=5 { - matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); } - matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); - matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); - matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + assert!(matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); + assert!(matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); + assert!(matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 4); @@ -531,16 +513,16 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..=5 { - matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); } - matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); - matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); - matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE); + assert!(matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); + assert!(matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); + assert!(matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); for i in 9..=11 { - matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN); + assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); } - matches!(RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); - matches!(RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, SimpleUtility::DRAW); + assert!(matches!(RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); + assert!(matches!(RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 1); diff --git a/src/solver/mod.rs b/src/solver/mod.rs index c01108f..2e08961 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -29,6 +29,7 @@ pub mod record { pub mod mur; pub mod sur; pub mod rem; + pub mod surcc; } /// Implementations of algorithms that can consume game implementations and @@ -61,13 +62,6 @@ pub mod algorithm { pub mod acyclic; pub mod cyclic; } - - /// These are custom records for certain solving algorithms that may need to - /// store additional data; these should not be accessible outside the - /// solving algorithms - mod record { - pub mod surcc; - } } #[cfg(test)] @@ -88,4 +82,7 @@ pub enum RecordType { SUR(PlayerCount), /// Remoteness record (no utilities). REM, + /// Simple Utility Remoteness with Child Counts records for a specific + /// number of players + SURCC(PlayerCount), } diff --git a/src/solver/algorithm/record/surcc.rs b/src/solver/record/surcc.rs similarity index 94% rename from src/solver/algorithm/record/surcc.rs rename to src/solver/record/surcc.rs index 580b911..35bd57b 100644 --- a/src/solver/algorithm/record/surcc.rs +++ b/src/solver/record/surcc.rs @@ -39,7 +39,7 @@ pub const UTILITY_SIZE: usize = 2; pub const CHILD_COUNT_SIZE: usize = 32; /// Type for child count -type ChildCount = u64; +pub type ChildCount = u64; /* SCHEMA GENERATOR */ @@ -48,7 +48,7 @@ type ChildCount = u64; pub fn schema(players: PlayerCount) -> Result { if RecordBuffer::bit_size(players) > BUFFER_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(players), + name: RecordType::SURCC(players).into(), hint: format!( "This record can only hold utility values for up to {} \ players, but there was an attempt to create a schema that \ @@ -58,7 +58,7 @@ pub fn schema(players: PlayerCount) -> Result { ), })? } else { - let mut schema = SchemaBuilder::new().of(RecordType::SUR(players)); + let mut schema = SchemaBuilder::new().of(RecordType::SURCC(players)); for i in 0..players { let name = &format!("P{} utility", i); @@ -124,11 +124,6 @@ impl Record for RecordBuffer { } impl RecordBuffer { - // Returns the string name for this record buffer - fn into_string(players: PlayerCount) -> String { - format!("Simple Utility Remoteness Child Count ({} players)", players) - } - /// Returns a new instance of a bit-packed record buffer that is able to /// store utility values for `players`. Fails if `players` is too high for /// the underlying buffer's capacity. @@ -136,7 +131,7 @@ impl RecordBuffer { pub fn new(players: PlayerCount) -> Result { if Self::bit_size(players) > BUFFER_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(players), + name: RecordType::SURCC(players).into(), hint: format!( "The record can only hold utility values for up to {} \ players, but there was an attempt to instantiate one for \ @@ -160,7 +155,7 @@ impl RecordBuffer { let len = bits.len(); if len > BUFFER_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(0), + name: RecordType::SURCC(0).into(), hint: format!( "The record implementation operates on a buffer of {} \ bits, but there was an attempt to instantiate one from a \ @@ -170,13 +165,13 @@ impl RecordBuffer { })? } else if len < Self::minimum_bit_size() { Err(RecordViolation { - name: RecordBuffer::into_string(0), + name: RecordType::SURCC(0).into(), hint: format!( "This record implementation stores utility values, but \ there was an attempt to instantiate one with from a buffer \ with {} bits, which is not enough to store a remoteness \ - value (which takes {} bits).", - len, REMOTENESS_SIZE, + and child count value (which takes {} bits).", + len, Self::minimum_bit_size(), ), })? } else { @@ -195,7 +190,7 @@ impl RecordBuffer { pub fn get_utility(&self, player: Turn) -> Result { if player >= self.players { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "A record was instantiated with {} utility entries, and \ there was an attempt to fetch the utility of player {} \ @@ -211,7 +206,7 @@ impl RecordBuffer { Ok(utility) } else { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "There was an attempt to deserialize a utility value \ of '{}' into a simple utility type.", @@ -253,7 +248,7 @@ impl RecordBuffer { ) -> Result<()> { if N != self.players { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "A record was instantiated with {} utility entries, and \ there was an attempt to use a {}-entry utility list to \ @@ -267,7 +262,7 @@ impl RecordBuffer { let size = util::min_ubits(utility); if size > UTILITY_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "This record implementation uses {} bits to store \ signed integers representing utility values, but \ @@ -293,7 +288,7 @@ impl RecordBuffer { let size = util::min_ubits(value); if size > REMOTENESS_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "This record implementation uses {} bits to store unsigned \ integers representing remoteness values, but there was an \ @@ -317,7 +312,7 @@ impl RecordBuffer { let size = util::min_ubits(value); if size > CHILD_COUNT_SIZE { Err(RecordViolation { - name: RecordBuffer::into_string(self.players), + name: RecordType::SURCC(self.players).into(), hint: format!( "This record implementation uses {} bits to store unsigned \ integers representing child count values, but there was an \ @@ -340,13 +335,13 @@ impl RecordBuffer { /// containing utility information for `players` as well as remoteness. #[inline(always)] const fn bit_size(players: usize) -> usize { - (players * UTILITY_SIZE) + REMOTENESS_SIZE + (players * UTILITY_SIZE) + REMOTENESS_SIZE + CHILD_COUNT_SIZE } /// Return the minimum number of bits needed for a valid record buffer. #[inline(always)] const fn minimum_bit_size() -> usize { - REMOTENESS_SIZE + REMOTENESS_SIZE + CHILD_COUNT_SIZE } /// Return the bit index of the remoteness entry start in the record buffer. @@ -371,7 +366,7 @@ impl RecordBuffer { /// (one that maximizes bit usage) with `length`. Ignores unused bits. #[inline(always)] const fn player_count(length: usize) -> usize { - (length - REMOTENESS_SIZE) / UTILITY_SIZE + (length - REMOTENESS_SIZE - CHILD_COUNT_SIZE) / UTILITY_SIZE } } @@ -535,4 +530,14 @@ mod tests { assert_eq!(record.get_remoteness(), MAX_REMOTENESS); assert!(record.set_utility(bad).is_err()); } + + #[test] + fn child_counts_retrieved_properly() -> Result<()> { + let mut buf = RecordBuffer::new(3)?; + buf.set_child_count(4)?; + + assert_eq!(buf.get_child_count(), 4); + + Ok(()) + } } diff --git a/src/solver/util.rs b/src/solver/util.rs index c0085b9..071d685 100644 --- a/src/solver/util.rs +++ b/src/solver/util.rs @@ -44,6 +44,9 @@ impl Into for RecordType { RecordType::REM => { format!("Remoteness (no utility)") }, + RecordType::SURCC(players) => { + format!("Simple Utility Remoteness with Child Count ({} players)", players) + }, } } } @@ -56,6 +59,7 @@ impl TryInto for RecordType { RecordType::RUR(players) => record::mur::schema(players), RecordType::SUR(players) => record::sur::schema(players), RecordType::REM => record::rem::schema(), + RecordType::SURCC(players) => record::surcc::schema(players), } } } From 0796ec318d24979a5a231c565e5a312488477066 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Thu, 18 Apr 2024 03:28:59 -0700 Subject: [PATCH 08/16] improved puzzle solver --- src/solver/algorithm/strong/puzzle.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 1c3f052..f7ba20b 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -47,10 +47,6 @@ where // Get end states and create frontiers let end_states = discover_child_counts(db, game)?; - // Contains states that have already been visited - let mut visited = HashSet::new(); - - // TODO: Change this to no longer store remoteness, just query db let mut winning_queue: VecDeque = VecDeque::new(); let mut losing_queue: VecDeque = VecDeque::new(); for end_state in end_states { @@ -66,20 +62,19 @@ where hint: format!("Primitive end position cannot have utility DRAW for a puzzle"), })?, } - visited.insert(end_state); - // Add ending state utility and remoteness to database update_db_record(db, end_state, game.utility(end_state), 0, 0)?; } // Perform BFS on winning states while let Some(state) = winning_queue.pop_front() { + let buf = RecordBuffer::from(db.get(state).unwrap())?; let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in game.retrograde(state) { - if !visited.contains(&parent) { + let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + if child_count > 0 { winning_queue.push_back(parent); - visited.insert(parent); update_db_record(db, parent, SimpleUtility::WIN, 1 + child_remoteness, 0)?; } } @@ -92,11 +87,11 @@ where let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in parents { - if !visited.contains(&parent) { - // Get child count from database + let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + if child_count > 0 { + // Update child count let mut buf = RecordBuffer::from(db.get(parent).unwrap()) .context("Failed to get record for middle state")?; - let new_child_count = buf.get_child_count() - 1; buf.set_child_count(new_child_count)?; db.put(parent, &buf); @@ -104,7 +99,6 @@ where // If all children have been solved, set this state as a losing state if new_child_count == 0 { losing_queue.push_back(parent); - visited.insert(parent); update_db_record(db, parent, SimpleUtility::LOSE, 1 + child_remoteness, 0)?; } } From 35dee4fac916e7887b11faf3f8f84be00dfdb85f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:29:19 +0000 Subject: [PATCH 09/16] Format Rust code using rustfmt --- src/solver/algorithm/strong/puzzle.rs | 524 ++++++++++++++++++++------ src/solver/record/surcc.rs | 6 +- src/solver/util.rs | 5 +- 3 files changed, 406 insertions(+), 129 deletions(-) diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index f7ba20b..21689ea 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -14,9 +14,9 @@ use crate::interface::IOMode; use crate::model::SimpleUtility; use crate::model::{Remoteness, State}; use crate::solver::error::SolverError::SolverViolation; -use std::collections::{HashMap, HashSet, VecDeque}; -use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; use crate::solver::record::surcc::{ChildCount, RecordBuffer}; +use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; +use std::collections::{HashMap, HashSet, VecDeque}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where @@ -69,13 +69,21 @@ where // Perform BFS on winning states while let Some(state) = winning_queue.pop_front() { let buf = RecordBuffer::from(db.get(state).unwrap())?; - let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); + let child_remoteness = + RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in game.retrograde(state) { - let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + let child_count = + RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); if child_count > 0 { winning_queue.push_back(parent); - update_db_record(db, parent, SimpleUtility::WIN, 1 + child_remoteness, 0)?; + update_db_record( + db, + parent, + SimpleUtility::WIN, + 1 + child_remoteness, + 0, + )?; } } } @@ -84,10 +92,12 @@ where // position. while let Some(state) = losing_queue.pop_front() { let parents = game.retrograde(state); - let child_remoteness = RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); + let child_remoteness = + RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); for parent in parents { - let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + let child_count = + RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); if child_count > 0 { // Update child count let mut buf = RecordBuffer::from(db.get(parent).unwrap()) @@ -99,7 +109,13 @@ where // If all children have been solved, set this state as a losing state if new_child_count == 0 { losing_queue.push_back(parent); - update_db_record(db, parent, SimpleUtility::LOSE, 1 + child_remoteness, 0)?; + update_db_record( + db, + parent, + SimpleUtility::LOSE, + 1 + child_remoteness, + 0, + )?; } } } @@ -109,8 +125,14 @@ where } /// Updates the database record for a puzzle with given simple utility and remoteness -fn update_db_record(db: &mut D, state: State, utility: SimpleUtility, remoteness: Remoteness, child_count: ChildCount) -> Result<()> -where +fn update_db_record( + db: &mut D, + state: State, + utility: SimpleUtility, + remoteness: Remoteness, + child_count: ChildCount, +) -> Result<()> +where D: KVStore, { let mut buf = RecordBuffer::from(db.get(state).unwrap()) @@ -126,12 +148,13 @@ where Ok(()) } -fn discover_child_counts( - db: &mut D, - game: &G, -) -> Result> +fn discover_child_counts(db: &mut D, game: &G) -> Result> where - G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + G: DTransition + + Bounded + + ClassicPuzzle + + Extensive<1> + + Game, D: KVStore, { let mut end_states = Vec::new(); @@ -146,22 +169,26 @@ fn discover_child_counts_helper( db: &mut D, game: &G, state: State, - end_states: &mut Vec + end_states: &mut Vec, ) -> Result<()> where - G: DTransition + Bounded + ClassicPuzzle + Extensive<1> + Game, + G: DTransition + + Bounded + + ClassicPuzzle + + Extensive<1> + + Game, D: KVStore, { let child_count = game.prograde(state).len() as ChildCount; if child_count == 0 { - end_states.push(state); + end_states.push(state); } // Initialize all utilies to draw; any utilities not set by the end must be // a drawn position - let mut buf = RecordBuffer::new(1) - .context("Failed to create record for state")?; + let mut buf = + RecordBuffer::new(1).context("Failed to create record for state")?; buf.set_utility([SimpleUtility::DRAW]) .context("Failed to set remoteness for state")?; buf.set_child_count(child_count) @@ -171,7 +198,11 @@ where // We need to check both prograde and retrograde; consider a game with 3 nodes where 0-->2 // and 1-->2. Then, starting from node 0 with only progrades would discover states 0 and 1; we // need to include retrogrades to discover state 2. - for &child in game.prograde(state).iter().chain(game.retrograde(state).iter()) { + for &child in game + .prograde(state) + .iter() + .chain(game.retrograde(state).iter()) + { if db.get(child).is_none() { discover_child_counts_helper(db, game, child, end_states)?; } @@ -209,13 +240,13 @@ where // THIS IS ONLY FOR TESTING PURPOSES struct TestDB { - memory: HashMap> + memory: HashMap>, } impl TestDB { fn initialize() -> Self { Self { - memory: HashMap::new() + memory: HashMap::new(), } } } @@ -224,9 +255,12 @@ impl KVStore for TestDB { fn put(&mut self, key: State, record: &R) { let new = BitVec::from(record.raw()).clone(); self.memory.insert(key, new); - } + } - fn get(&self, key: State) -> Option<&bitvec::prelude::BitSlice> { + fn get( + &self, + key: State, + ) -> Option<&bitvec::prelude::BitSlice> { let vec_opt = self.memory.get(&key); match vec_opt { None => None, @@ -235,7 +269,7 @@ impl KVStore for TestDB { } fn del(&mut self, key: State) { - unimplemented![]; + unimplemented![]; } } @@ -249,6 +283,8 @@ where #[cfg(test)] mod tests { + use crate::database::{KVStore, Tabular}; + use crate::game::mock; use crate::game::{ Bounded, ClassicPuzzle, DTransition, Extensive, Game, GameData, SimpleSum, @@ -256,14 +292,14 @@ mod tests { use crate::interface::{IOMode, SolutionMode}; use crate::model::SimpleUtility; use crate::model::{State, Turn}; + use crate::node; + use crate::solver::record::surcc::RecordBuffer; use anyhow::Result; use std::collections::{HashMap, VecDeque}; - use crate::solver::record::surcc::RecordBuffer; - use crate::database::{KVStore, Tabular}; - use crate::game::mock; - use crate::node; - use super::{discover_child_counts, volatile_database, reverse_bfs_solver, TestDB}; + use super::{ + discover_child_counts, reverse_bfs_solver, volatile_database, TestDB, + }; struct GameNode { children: Vec, @@ -348,18 +384,25 @@ mod tests { #[test] fn game_with_single_node_win() -> Result<()> { let graph = PuzzleGraph { - adj_list: vec![ - GameNode { children: vec![], utility: Some(SimpleUtility::WIN) } - ], + adj_list: vec![GameNode { + children: vec![], + utility: Some(SimpleUtility::WIN), + }], }; - + // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); - assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 0); - + assert!(matches!( + RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); + assert_eq!( + RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), + 0 + ); + Ok(()) } @@ -367,8 +410,14 @@ mod tests { fn game_with_two_nodes_win() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { children: vec![1], utility: None }, - GameNode { children: vec![], utility: Some(SimpleUtility::WIN) }, + GameNode { + children: vec![1], + utility: None, + }, + GameNode { + children: vec![], + utility: Some(SimpleUtility::WIN), + }, ], }; @@ -376,11 +425,23 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!(RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); - assert!(matches!(RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); - - assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 0); + assert!(matches!( + RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); + assert!(matches!( + RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); + + assert_eq!( + RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), + 0 + ); Ok(()) } @@ -389,11 +450,26 @@ mod tests { fn game_with_dag_win() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { children: vec![1, 2, 4], utility: None }, - GameNode { children: vec![3], utility: None }, - GameNode { children: vec![3, 4], utility: None }, - GameNode { children: vec![4], utility: None }, - GameNode { children: vec![], utility: Some(SimpleUtility::WIN) }, + GameNode { + children: vec![1, 2, 4], + utility: None, + }, + GameNode { + children: vec![3], + utility: None, + }, + GameNode { + children: vec![3, 4], + utility: None, + }, + GameNode { + children: vec![4], + utility: None, + }, + GameNode { + children: vec![], + utility: Some(SimpleUtility::WIN), + }, ], }; @@ -402,14 +478,32 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); + assert!(matches!( + RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); } - assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 2); - assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 0); + assert_eq!( + RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), + 2 + ); + assert_eq!( + RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), + 0 + ); Ok(()) } @@ -418,12 +512,30 @@ mod tests { fn game_with_cyclic_graph_draw() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { children: vec![1, 2, 4], utility: None }, - GameNode { children: vec![3], utility: None }, - GameNode { children: vec![3, 4], utility: None }, - GameNode { children: vec![4], utility: None }, - GameNode { children: vec![5], utility: None }, - GameNode { children: vec![2, 4], utility: None }, + GameNode { + children: vec![1, 2, 4], + utility: None, + }, + GameNode { + children: vec![3], + utility: None, + }, + GameNode { + children: vec![3, 4], + utility: None, + }, + GameNode { + children: vec![4], + utility: None, + }, + GameNode { + children: vec![5], + utility: None, + }, + GameNode { + children: vec![2, 4], + utility: None, + }, ], }; @@ -432,7 +544,10 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); + assert!(matches!( + RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, + SimpleUtility::DRAW + )); } Ok(()) @@ -442,39 +557,105 @@ mod tests { fn game_with_dag_win_and_lose() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { utility: None, children: vec![3] }, - GameNode { utility: None, children: vec![4] }, - GameNode { utility: None, children: vec![4] }, - GameNode { utility: None, children: vec![4, 5] }, - GameNode { utility: None, children: vec![8, 0] }, - GameNode { utility: Some(SimpleUtility::WIN), children: vec![] }, - GameNode { utility: None, children: vec![8] }, - GameNode { utility: None, children: vec![6, 8] }, - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![] }, + GameNode { + utility: None, + children: vec![3], + }, + GameNode { + utility: None, + children: vec![4], + }, + GameNode { + utility: None, + children: vec![4], + }, + GameNode { + utility: None, + children: vec![4, 5], + }, + GameNode { + utility: None, + children: vec![8, 0], + }, + GameNode { + utility: Some(SimpleUtility::WIN), + children: vec![], + }, + GameNode { + utility: None, + children: vec![8], + }, + GameNode { + utility: None, + children: vec![6, 8], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![], + }, ], }; - // Solve game + // Solve game let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); for i in 0..=5 { - assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); + assert!(matches!( + RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); } - assert!(matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); - assert!(matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); - assert!(matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); - - assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); - assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 4); - assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 4); - assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 3); - assert_eq!(RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), 0); - assert_eq!(RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(7).unwrap())?.get_remoteness(), 2); - assert_eq!(RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), 0); - + assert!(matches!( + RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, + SimpleUtility::LOSE + )); + assert!(matches!( + RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, + SimpleUtility::LOSE + )); + assert!(matches!( + RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, + SimpleUtility::LOSE + )); + + assert_eq!( + RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), + 2 + ); + assert_eq!( + RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), + 4 + ); + assert_eq!( + RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), + 4 + ); + assert_eq!( + RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), + 3 + ); + assert_eq!( + RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), + 0 + ); + assert_eq!( + RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(7).unwrap())?.get_remoteness(), + 2 + ); + assert_eq!( + RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), + 0 + ); + Ok(()) } @@ -482,23 +663,62 @@ mod tests { fn game_with_wld() -> Result<()> { let graph = PuzzleGraph { adj_list: vec![ - GameNode { utility: None, children: vec![3] }, - GameNode { utility: None, children: vec![4, 5] }, - GameNode { utility: None, children: vec![4] }, - GameNode { utility: None, children: vec![4, 5] }, - GameNode { utility: None, children: vec![8, 0] }, - GameNode { utility: Some(SimpleUtility::WIN), children: vec![] }, - - GameNode { utility: None, children: vec![8] }, - GameNode { utility: None, children: vec![6, 8, 13] }, - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![] }, - - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![10] }, - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![11] }, - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![9, 2] }, - - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![7] }, - GameNode { utility: Some(SimpleUtility::LOSE), children: vec![12] }, + GameNode { + utility: None, + children: vec![3], + }, + GameNode { + utility: None, + children: vec![4, 5], + }, + GameNode { + utility: None, + children: vec![4], + }, + GameNode { + utility: None, + children: vec![4, 5], + }, + GameNode { + utility: None, + children: vec![8, 0], + }, + GameNode { + utility: Some(SimpleUtility::WIN), + children: vec![], + }, + GameNode { + utility: None, + children: vec![8], + }, + GameNode { + utility: None, + children: vec![6, 8, 13], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![10], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![11], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![9, 2], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![7], + }, + GameNode { + utility: Some(SimpleUtility::LOSE), + children: vec![12], + }, ], }; @@ -507,29 +727,83 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..=5 { - assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); + assert!(matches!( + RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); } - assert!(matches!(RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); - assert!(matches!(RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); - assert!(matches!(RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, SimpleUtility::LOSE)); + assert!(matches!( + RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, + SimpleUtility::LOSE + )); + assert!(matches!( + RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, + SimpleUtility::DRAW + )); + assert!(matches!( + RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, + SimpleUtility::LOSE + )); for i in 9..=11 { - assert!(matches!(RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::WIN)); + assert!(matches!( + RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, + SimpleUtility::WIN + )); } - assert!(matches!(RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); - assert!(matches!(RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, SimpleUtility::DRAW)); - - assert_eq!(RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 2); - assert_eq!(RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), 4); - assert_eq!(RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), 3); - assert_eq!(RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), 0); - assert_eq!(RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), 1); - assert_eq!(RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), 0); - assert_eq!(RecordBuffer::from(db.get(9).unwrap())?.get_remoteness(), 7); - assert_eq!(RecordBuffer::from(db.get(10).unwrap())?.get_remoteness(), 6); - assert_eq!(RecordBuffer::from(db.get(11).unwrap())?.get_remoteness(), 5); - + assert!(matches!( + RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, + SimpleUtility::DRAW + )); + assert!(matches!( + RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, + SimpleUtility::DRAW + )); + + assert_eq!( + RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), + 2 + ); + assert_eq!( + RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), + 4 + ); + assert_eq!( + RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), + 3 + ); + assert_eq!( + RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), + 0 + ); + assert_eq!( + RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), + 1 + ); + assert_eq!( + RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), + 0 + ); + assert_eq!( + RecordBuffer::from(db.get(9).unwrap())?.get_remoteness(), + 7 + ); + assert_eq!( + RecordBuffer::from(db.get(10).unwrap())?.get_remoteness(), + 6 + ); + assert_eq!( + RecordBuffer::from(db.get(11).unwrap())?.get_remoteness(), + 5 + ); + Ok(()) } } diff --git a/src/solver/record/surcc.rs b/src/solver/record/surcc.rs index 35bd57b..56a6132 100644 --- a/src/solver/record/surcc.rs +++ b/src/solver/record/surcc.rs @@ -1,4 +1,3 @@ - //! # Simple-Utility Remoteness with Child Counts (SURCC) Record Module //! //! Implementation of a database record buffer for storing simple utilities @@ -171,7 +170,8 @@ impl RecordBuffer { there was an attempt to instantiate one with from a buffer \ with {} bits, which is not enough to store a remoteness \ and child count value (which takes {} bits).", - len, Self::minimum_bit_size(), + len, + Self::minimum_bit_size(), ), })? } else { @@ -356,7 +356,7 @@ impl RecordBuffer { player * UTILITY_SIZE } - /// Return the bit index of the child count entry + /// Return the bit index of the child count entry #[inline(always)] const fn child_count_index(players: usize) -> usize { players * UTILITY_SIZE + REMOTENESS_SIZE diff --git a/src/solver/util.rs b/src/solver/util.rs index 071d685..f28cd39 100644 --- a/src/solver/util.rs +++ b/src/solver/util.rs @@ -45,7 +45,10 @@ impl Into for RecordType { format!("Remoteness (no utility)") }, RecordType::SURCC(players) => { - format!("Simple Utility Remoteness with Child Count ({} players)", players) + format!( + "Simple Utility Remoteness with Child Count ({} players)", + players + ) }, } } From d0ab99aadc2b4671a2f19bf2fa675d50141ed984 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Thu, 18 Apr 2024 03:34:20 -0700 Subject: [PATCH 10/16] formatting fix --- src/solver/algorithm/strong/puzzle.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 21689ea..542d414 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -55,11 +55,13 @@ where SimpleUtility::LOSE => losing_queue.push_back(end_state), SimpleUtility::TIE => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility TIE for a puzzle"), + hint: format!("Primitive end position cannot have utility TIE + for a puzzle"), })?, SimpleUtility::DRAW => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility DRAW for a puzzle"), + hint: format!("Primitive end position cannot have utility DRAW + for a puzzle"), })?, } // Add ending state utility and remoteness to database @@ -88,7 +90,8 @@ where } } - // Perform BFS on losing states, where remoteness is the longest path to a losing primitive + // Perform BFS on losing states, where remoteness is the longest path to a + // losing primitive // position. while let Some(state) = losing_queue.pop_front() { let parents = game.retrograde(state); @@ -106,7 +109,8 @@ where buf.set_child_count(new_child_count)?; db.put(parent, &buf); - // If all children have been solved, set this state as a losing state + // If all children have been solved, set this state as a losing + // state if new_child_count == 0 { losing_queue.push_back(parent); update_db_record( @@ -124,7 +128,8 @@ where Ok(()) } -/// Updates the database record for a puzzle with given simple utility and remoteness +/// Updates the database record for a puzzle with given simple utility, +/// remoteness, and child count fn update_db_record( db: &mut D, state: State, @@ -195,9 +200,10 @@ where .context("Failed to set child count for state.")?; db.put(state, &buf); - // We need to check both prograde and retrograde; consider a game with 3 nodes where 0-->2 - // and 1-->2. Then, starting from node 0 with only progrades would discover states 0 and 1; we - // need to include retrogrades to discover state 2. + // We need to check both prograde and retrograde; consider a game with 3 + // nodes where 0-->2 and 1-->2. Then, starting from node 0 with only + // progrades would discover states 0 and 1; we need to include retrogrades + // to discover state 2. for &child in game .prograde(state) .iter() From d1456c90ad5e9a1ca0bdc52984a005bab4201c4e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:34:34 +0000 Subject: [PATCH 11/16] Format Rust code using rustfmt --- src/solver/algorithm/strong/puzzle.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index 542d414..a42cb2d 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -55,13 +55,17 @@ where SimpleUtility::LOSE => losing_queue.push_back(end_state), SimpleUtility::TIE => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility TIE - for a puzzle"), + hint: format!( + "Primitive end position cannot have utility TIE + for a puzzle" + ), })?, SimpleUtility::DRAW => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility DRAW - for a puzzle"), + hint: format!( + "Primitive end position cannot have utility DRAW + for a puzzle" + ), })?, } // Add ending state utility and remoteness to database From 350e03144852e2991eb3a9de0b9a795c9900c1a4 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Mon, 22 Apr 2024 18:40:56 -0700 Subject: [PATCH 12/16] Game utility interfaces require Extensive trait to be implemented. Upstream implementation of Extensive trait for ClassicPuzzle --- src/game/mod.rs | 29 +++++++++++++++++++++++---- src/solver/algorithm/strong/puzzle.rs | 28 +++++++------------------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index 9aeaa89..d9e5ccd 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -289,7 +289,10 @@ where /* UTILITY INTERFACES */ /// TODO -pub trait GeneralSum { +pub trait GeneralSum +where + Self: Extensive, +{ /// If `state` is terminal, returns the utility vector associated with that /// state, where `utility[i]` is the utility of the state for player `i`. If /// the state is not terminal it is recommended that this function panics. @@ -297,7 +300,10 @@ pub trait GeneralSum { } /// TODO -pub trait SimpleSum { +pub trait SimpleSum +where + Self: Extensive, +{ /// If `state` is terminal, returns the utility vector associated with that /// state, where `utility[i]` is the utility of the state for player `i`. If /// the state is not terminal, it is recommended that this function panics. @@ -316,7 +322,10 @@ pub trait SimpleSum { /// Since either entry determines the other, knowing one of the entries and the /// turn information for a given state provides enough information to determine /// both players' utilities. -pub trait ClassicGame { +pub trait ClassicGame +where + Self: Extensive<2>, +{ /// If `state` is terminal, returns the utility of the player whose turn it /// is at that state. If the state is not terminal, it is recommended that /// this function panics. @@ -337,7 +346,10 @@ pub trait ClassicGame { /// A draw state is one where there is no way to reach a winning state but it is /// possible to play forever without reaching a losing state. A tie state is any /// state that does not subjectively fit into any of the above categories. -pub trait ClassicPuzzle { +pub trait ClassicPuzzle +where + Self: Extensive<1>, +{ /// If `state` is terminal, returns the utility of the puzzle's player. If /// the state is not terminal, it is recommended that this function panics. fn utility(&self, state: State) -> SimpleUtility; @@ -383,3 +395,12 @@ where [ClassicPuzzle::utility(self, state)] } } + +impl Extensive<1> for G +where + G: ClassicPuzzle, +{ + fn turn(&self, state: State) -> Turn { + 0 + } +} diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index a42cb2d..e414e18 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -23,7 +23,6 @@ where G: DTransition + Bounded + ClassicPuzzle - + Extensive<1> + Game, { let mut db = volatile_database(game) @@ -40,7 +39,6 @@ where G: DTransition + Bounded + ClassicPuzzle - + Extensive<1> + Game, D: KVStore, { @@ -55,17 +53,13 @@ where SimpleUtility::LOSE => losing_queue.push_back(end_state), SimpleUtility::TIE => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!( - "Primitive end position cannot have utility TIE - for a puzzle" - ), + hint: format!("Primitive end position cannot have utility TIE + for a puzzle"), })?, SimpleUtility::DRAW => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!( - "Primitive end position cannot have utility DRAW - for a puzzle" - ), + hint: format!("Primitive end position cannot have utility DRAW + for a puzzle"), })?, } // Add ending state utility and remoteness to database @@ -162,7 +156,6 @@ where G: DTransition + Bounded + ClassicPuzzle - + Extensive<1> + Game, D: KVStore, { @@ -184,7 +177,6 @@ where G: DTransition + Bounded + ClassicPuzzle - + Extensive<1> + Game, D: KVStore, { @@ -230,7 +222,7 @@ where /* fn volatile_database(game: &G) -> Result where - G: Extensive + Game, + G: Extensive<1> + Game, { let id = game.id(); let db = volatile::Database::initialize(); @@ -283,9 +275,9 @@ impl KVStore for TestDB { } } -fn volatile_database(game: &G) -> Result +fn volatile_database(game: &G) -> Result where - G: Extensive + Game, + G: Extensive<1> + Game, { let db = TestDB::initialize(); Ok(db) @@ -359,12 +351,6 @@ mod tests { } } - impl Extensive<1> for PuzzleGraph { - fn turn(&self, state: State) -> Turn { - 0 - } - } - impl ClassicPuzzle for PuzzleGraph { fn utility(&self, state: State) -> SimpleUtility { self.adj_list[state as usize] From 64c29ac550ef664d39b66d8c95150b853cc56c95 Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Tue, 23 Apr 2024 00:52:23 -0700 Subject: [PATCH 13/16] Redefine utility interface relationships for simplicity --- src/game/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index d9e5ccd..b9a7b88 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -324,7 +324,7 @@ where /// both players' utilities. pub trait ClassicGame where - Self: Extensive<2>, + Self: SimpleSum<2>, { /// If `state` is terminal, returns the utility of the player whose turn it /// is at that state. If the state is not terminal, it is recommended that @@ -348,7 +348,7 @@ where /// state that does not subjectively fit into any of the above categories. pub trait ClassicPuzzle where - Self: Extensive<1>, + Self: SimpleSum<1>, { /// If `state` is terminal, returns the utility of the puzzle's player. If /// the state is not terminal, it is recommended that this function panics. @@ -368,7 +368,7 @@ where impl SimpleSum<2> for G where - G: ClassicGame + Extensive<2>, + G: ClassicGame, { fn utility(&self, state: State) -> [SimpleUtility; 2] { let player_utility = ClassicGame::utility(self, state); From b9911ac93fd4ebc775c2dd9de8eccf70ce66328a Mon Sep 17 00:00:00 2001 From: Ishir Garg Date: Tue, 23 Apr 2024 04:34:19 -0700 Subject: [PATCH 14/16] Refactor cyclic and puzzle solvers --- src/model.rs | 2 +- src/solver/algorithm/strong/cyclic.rs | 20 +- src/solver/algorithm/strong/puzzle.rs | 327 +++++++++++--------------- 3 files changed, 144 insertions(+), 205 deletions(-) diff --git a/src/model.rs b/src/model.rs index 55228d2..929dc81 100644 --- a/src/model.rs +++ b/src/model.rs @@ -28,7 +28,7 @@ pub type Turn = usize; pub type Utility = i64; /// TODO -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum SimpleUtility { WIN = 0, LOSE = 1, diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index 0b2a07f..0cf7771 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -9,6 +9,7 @@ //! - Ishir Garg, 3/12/2024 (ishirgarg@berkeley.edu) use anyhow::{Context, Result}; + use std::collections::{HashMap, VecDeque}; use crate::database::volatile; @@ -20,18 +21,7 @@ use crate::model::{PlayerCount, Remoteness, State, Turn}; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; -/* CONSTANTS */ - -/// The exact number of bits that are used to encode remoteness. -const REMOTENESS_SIZE: usize = 16; - -/// The maximum number of bits that can be used to encode a record. -const BUFFER_SIZE: usize = 128; - -/// The exact number of bits that are used to encode utility for one player. -const UTILITY_SIZE: usize = 2; - -pub fn two_player_zero_sum_dynamic_solver( +pub fn dynamic_solver( game: &G, mode: IOMode, ) -> Result<()> @@ -40,11 +30,11 @@ where { let mut db = volatile_database(game).context("Failed to initialize database.")?; - basic_loopy_solver(game, &mut db)?; + cyclic_solver(game, &mut db)?; Ok(()) } -fn basic_loopy_solver(game: &G, db: &mut D) -> Result<()> +fn cyclic_solver(game: &G, db: &mut D) -> Result<()> where G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, D: KVStore, @@ -91,7 +81,7 @@ where let parents = game.retrograde(child); // If child is a losing position - if matches!(child_utility, SimpleUtility::LOSE) { + if let SimpleUtility::LOSE = child_utility { for parent in parents { if *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueueing stage") > 0 { // Add database entry diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index e414e18..c645043 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -6,6 +6,9 @@ //! - Ishir Garg (ishirgarg@berkeley.edu) use anyhow::{Context, Result}; +use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; + +use std::collections::{HashMap, VecDeque}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; @@ -15,8 +18,6 @@ use crate::model::SimpleUtility; use crate::model::{Remoteness, State}; use crate::solver::error::SolverError::SolverViolation; use crate::solver::record::surcc::{ChildCount, RecordBuffer}; -use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; -use std::collections::{HashMap, HashSet, VecDeque}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where @@ -34,6 +35,12 @@ where Ok(()) } +/// Runs BFS starting from the ending primitive positions of a game, and working +/// its way up the game tree in reverse. Assigns a remoteness and simple +/// utiliity to every winning and losing position. Draws (positions where +/// winning is impossible, but it is possible to play forever without losing) +/// not assigned a remoteness. This implementation uses the SURCC record to +/// store child count along with utility and remoteness. fn reverse_bfs_solver(db: &mut D, game: &G) -> Result<()> where G: DTransition @@ -42,7 +49,6 @@ where + Game, D: KVStore, { - // Get end states and create frontiers let end_states = discover_child_counts(db, game)?; let mut winning_queue: VecDeque = VecDeque::new(); @@ -63,18 +69,34 @@ where })?, } // Add ending state utility and remoteness to database - update_db_record(db, end_state, game.utility(end_state), 0, 0)?; + update_db_record(db, end_state, ClassicPuzzle::utility(game, end_state), 0, 0)?; } - // Perform BFS on winning states + reverse_bfs_winning_states(db, game, &mut winning_queue)?; + reverse_bfs_losing_states(db, game, &mut losing_queue)?; + + Ok(()) +} + +/// Performs BFS on winning states, marking visited states as a win +fn reverse_bfs_winning_states( + db: &mut D, + game: &G, + winning_queue: &mut VecDeque +) -> Result<()> +where + G: DTransition + + Bounded + + ClassicPuzzle + + Game, + D: KVStore, +{ while let Some(state) = winning_queue.pop_front() { let buf = RecordBuffer::from(db.get(state).unwrap())?; - let child_remoteness = - RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); + let child_remoteness = buf.get_remoteness(); for parent in game.retrograde(state) { - let child_count = - RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); if child_count > 0 { winning_queue.push_back(parent); update_db_record( @@ -87,10 +109,24 @@ where } } } + + Ok(()) +} - // Perform BFS on losing states, where remoteness is the longest path to a - // losing primitive - // position. +/// Performs BFS on losing states, marking visited states as a loss. Remoteness +/// is the shortest path to a primitive losing position. +fn reverse_bfs_losing_states( + db: &mut D, + game: &G, + losing_queue: &mut VecDeque +)-> Result<()> +where + G: DTransition + + Bounded + + ClassicPuzzle + + Game, + D: KVStore, +{ while let Some(state) = losing_queue.pop_front() { let parents = game.retrograde(state); let child_remoteness = @@ -122,7 +158,7 @@ where } } } - + Ok(()) } @@ -160,14 +196,12 @@ where D: KVStore, { let mut end_states = Vec::new(); - discover_child_counts_helper(db, game, game.start(), &mut end_states)?; + discover_child_counts_from_state(db, game, game.start(), &mut end_states)?; Ok(end_states) } -/// Adds child counts for each position to the database -/// Also returns a vector of all primitive positions -fn discover_child_counts_helper( +fn discover_child_counts_from_state( db: &mut D, game: &G, state: State, @@ -197,16 +231,16 @@ where db.put(state, &buf); // We need to check both prograde and retrograde; consider a game with 3 - // nodes where 0-->2 and 1-->2. Then, starting from node 0 with only - // progrades would discover states 0 and 1; we need to include retrogrades - // to discover state 2. + // nodes where the edges are `0` → `2` and `1` → `2`. Then, starting from + // node 0 with only progrades would discover states 0 and 1; we need to + // include retrogrades to discover state 2. for &child in game .prograde(state) .iter() .chain(game.retrograde(state).iter()) { if db.get(child).is_none() { - discover_child_counts_helper(db, game, child, end_states)?; + discover_child_counts_from_state(db, game, child, end_states)?; } } @@ -227,7 +261,7 @@ where let id = game.id(); let db = volatile::Database::initialize(); - let schema = RecordType::SUR(1) + let schema = RecordType::SURCC(1) .try_into() .context("Failed to create table schema for solver records.")?; db.create_table(&id, schema) @@ -240,7 +274,6 @@ where } */ -// THIS IS ONLY FOR TESTING PURPOSES struct TestDB { memory: HashMap>, } @@ -271,7 +304,7 @@ impl KVStore for TestDB { } fn del(&mut self, key: State) { - unimplemented![]; + unimplemented!(); } } @@ -285,6 +318,9 @@ where #[cfg(test)] mod tests { + use anyhow::Result; + + use crate::game::mock::{Session, SessionBuilder}; use crate::database::{KVStore, Tabular}; use crate::game::mock; use crate::game::{ @@ -296,11 +332,9 @@ mod tests { use crate::model::{State, Turn}; use crate::node; use crate::solver::record::surcc::RecordBuffer; - use anyhow::Result; - use std::collections::{HashMap, VecDeque}; use super::{ - discover_child_counts, reverse_bfs_solver, volatile_database, TestDB, + reverse_bfs_solver, volatile_database, }; struct GameNode { @@ -390,10 +424,10 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); + ); assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 0 @@ -421,14 +455,14 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); - assert!(matches!( + ); + assert_eq!( RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); + ); assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), @@ -480,26 +514,14 @@ mod tests { )); } - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 0 - ); + let expected_remoteness = [1, 2, 1, 1, 0]; + + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ) + } Ok(()) } @@ -540,10 +562,10 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::DRAW - )); + ); } Ok(()) @@ -596,61 +618,33 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - for i in 0..=5 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); - } - assert!(matches!( - RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, + let expected_utilities = [ + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::LOSE, SimpleUtility::LOSE - )); + ]; - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 3 - ); - assert_eq!( - RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(7).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), - 0 - ); + let expected_remoteness = [2, 4, 4, 1, 3, 0, 1, 2, 0]; + + for (i, &utility) in expected_utilities.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + utility + ); + } + + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ); + } Ok(()) } @@ -688,8 +682,8 @@ mod tests { children: vec![8], }, GameNode { - utility: None, - children: vec![6, 8, 13], + utility: Some(SimpleUtility::LOSE), + children: vec![9, 2], }, GameNode { utility: Some(SimpleUtility::LOSE), @@ -701,15 +695,15 @@ mod tests { }, GameNode { utility: Some(SimpleUtility::LOSE), - children: vec![11], + children: vec![7], }, GameNode { - utility: Some(SimpleUtility::LOSE), - children: vec![9, 2], + utility: None, + children: vec![6, 8, 13], }, GameNode { utility: Some(SimpleUtility::LOSE), - children: vec![7], + children: vec![11], }, GameNode { utility: Some(SimpleUtility::LOSE), @@ -722,83 +716,38 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - for i in 0..=5 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); + let expected_utilities = [ + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::DRAW, + SimpleUtility::DRAW, + SimpleUtility::DRAW, + ]; + + let expected_remoteness = [2, 1, 4, 1, 3, 0, 1, 5, 0, 7, 6]; + + for (i, &utility) in expected_utilities.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + utility + ); } - assert!(matches!( - RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert!(matches!( - RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - for i in 9..=11 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); - } - assert!(matches!( - RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert!(matches!( - RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 3 - ); - assert_eq!( - RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(9).unwrap())?.get_remoteness(), - 7 - ); - assert_eq!( - RecordBuffer::from(db.get(10).unwrap())?.get_remoteness(), - 6 - ); - assert_eq!( - RecordBuffer::from(db.get(11).unwrap())?.get_remoteness(), - 5 - ); + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ); + } Ok(()) } From 2afef5a106fc77d7d058d82c826083cafaeb194d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 25 Apr 2024 02:16:15 +0000 Subject: [PATCH 15/16] Format Rust code using rustfmt --- src/game/mod.rs | 4 +- src/solver/algorithm/strong/cyclic.rs | 5 +- src/solver/algorithm/strong/puzzle.rs | 106 ++++++++++++++------------ 3 files changed, 60 insertions(+), 55 deletions(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index b9a7b88..359bf22 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -289,7 +289,7 @@ where /* UTILITY INTERFACES */ /// TODO -pub trait GeneralSum +pub trait GeneralSum where Self: Extensive, { @@ -397,7 +397,7 @@ where } impl Extensive<1> for G -where +where G: ClassicPuzzle, { fn turn(&self, state: State) -> Turn { diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index 0cf7771..2362516 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -21,10 +21,7 @@ use crate::model::{PlayerCount, Remoteness, State, Turn}; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; -pub fn dynamic_solver( - game: &G, - mode: IOMode, -) -> Result<()> +pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, { diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index c645043..d0cdf35 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -21,10 +21,7 @@ use crate::solver::record::surcc::{ChildCount, RecordBuffer}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, { let mut db = volatile_database(game) .context("Failed to initialize volatile database.")?; @@ -43,10 +40,7 @@ where /// store child count along with utility and remoteness. fn reverse_bfs_solver(db: &mut D, game: &G) -> Result<()> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, D: KVStore, { let end_states = discover_child_counts(db, game)?; @@ -59,17 +53,27 @@ where SimpleUtility::LOSE => losing_queue.push_back(end_state), SimpleUtility::TIE => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility TIE - for a puzzle"), + hint: format!( + "Primitive end position cannot have utility TIE + for a puzzle" + ), })?, SimpleUtility::DRAW => Err(SolverViolation { name: "PuzzleSolver".to_string(), - hint: format!("Primitive end position cannot have utility DRAW - for a puzzle"), + hint: format!( + "Primitive end position cannot have utility DRAW + for a puzzle" + ), })?, } // Add ending state utility and remoteness to database - update_db_record(db, end_state, ClassicPuzzle::utility(game, end_state), 0, 0)?; + update_db_record( + db, + end_state, + ClassicPuzzle::utility(game, end_state), + 0, + 0, + )?; } reverse_bfs_winning_states(db, game, &mut winning_queue)?; @@ -82,13 +86,10 @@ where fn reverse_bfs_winning_states( db: &mut D, game: &G, - winning_queue: &mut VecDeque + winning_queue: &mut VecDeque, ) -> Result<()> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, D: KVStore, { while let Some(state) = winning_queue.pop_front() { @@ -96,7 +97,8 @@ where let child_remoteness = buf.get_remoteness(); for parent in game.retrograde(state) { - let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + let child_count = + RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); if child_count > 0 { winning_queue.push_back(parent); update_db_record( @@ -109,7 +111,7 @@ where } } } - + Ok(()) } @@ -118,13 +120,10 @@ where fn reverse_bfs_losing_states( db: &mut D, game: &G, - losing_queue: &mut VecDeque -)-> Result<()> + losing_queue: &mut VecDeque, +) -> Result<()> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, D: KVStore, { while let Some(state) = losing_queue.pop_front() { @@ -158,7 +157,7 @@ where } } } - + Ok(()) } @@ -189,10 +188,7 @@ where fn discover_child_counts(db: &mut D, game: &G) -> Result> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, D: KVStore, { let mut end_states = Vec::new(); @@ -208,10 +204,7 @@ fn discover_child_counts_from_state( end_states: &mut Vec, ) -> Result<()> where - G: DTransition - + Bounded - + ClassicPuzzle - + Game, + G: DTransition + Bounded + ClassicPuzzle + Game, D: KVStore, { let child_count = game.prograde(state).len() as ChildCount; @@ -320,9 +313,9 @@ where mod tests { use anyhow::Result; - use crate::game::mock::{Session, SessionBuilder}; use crate::database::{KVStore, Tabular}; use crate::game::mock; + use crate::game::mock::{Session, SessionBuilder}; use crate::game::{ Bounded, ClassicPuzzle, DTransition, Extensive, Game, GameData, SimpleSum, @@ -333,9 +326,7 @@ mod tests { use crate::node; use crate::solver::record::surcc::RecordBuffer; - use super::{ - reverse_bfs_solver, volatile_database, - }; + use super::{reverse_bfs_solver, volatile_database}; struct GameNode { children: Vec, @@ -516,7 +507,10 @@ mod tests { let expected_remoteness = [1, 2, 1, 1, 0]; - for (i, &remoteness) in expected_remoteness.iter().enumerate() { + for (i, &remoteness) in expected_remoteness + .iter() + .enumerate() + { assert_eq!( RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), remoteness @@ -627,19 +621,26 @@ mod tests { SimpleUtility::WIN, SimpleUtility::LOSE, SimpleUtility::LOSE, - SimpleUtility::LOSE - ]; + SimpleUtility::LOSE, + ]; let expected_remoteness = [2, 4, 4, 1, 3, 0, 1, 2, 0]; - for (i, &utility) in expected_utilities.iter().enumerate() { + for (i, &utility) in expected_utilities + .iter() + .enumerate() + { assert_eq!( - RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + RecordBuffer::from(db.get(i as u64).unwrap())? + .get_utility(0)?, utility ); } - for (i, &remoteness) in expected_remoteness.iter().enumerate() { + for (i, &remoteness) in expected_remoteness + .iter() + .enumerate() + { assert_eq!( RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), remoteness @@ -731,18 +732,25 @@ mod tests { SimpleUtility::DRAW, SimpleUtility::DRAW, SimpleUtility::DRAW, - ]; + ]; let expected_remoteness = [2, 1, 4, 1, 3, 0, 1, 5, 0, 7, 6]; - for (i, &utility) in expected_utilities.iter().enumerate() { + for (i, &utility) in expected_utilities + .iter() + .enumerate() + { assert_eq!( - RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + RecordBuffer::from(db.get(i as u64).unwrap())? + .get_utility(0)?, utility ); } - for (i, &remoteness) in expected_remoteness.iter().enumerate() { + for (i, &remoteness) in expected_remoteness + .iter() + .enumerate() + { assert_eq!( RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), remoteness From 6bdaa8a8f51b58e5223b94fbbcfbfc493a071904 Mon Sep 17 00:00:00 2001 From: Max Fierro Date: Tue, 14 May 2024 23:21:50 -0700 Subject: [PATCH 16/16] Deleted repeated tests, cleaned up a bit --- src/solver/record/surcc.rs | 162 +------------------------------------ 1 file changed, 1 insertion(+), 161 deletions(-) diff --git a/src/solver/record/surcc.rs b/src/solver/record/surcc.rs index 3ffba4b..d446109 100644 --- a/src/solver/record/surcc.rs +++ b/src/solver/record/surcc.rs @@ -375,166 +375,6 @@ impl RecordBuffer { mod tests { use super::*; - // The maximum and minimum numeric values that can be represented with - // exactly UTILITY_SIZE bits in two's complement. - // - // Example if UTILITY_SIZE is 8: - // - // * `MAX_UTILITY = 0b01111111 = 127 = 2^(8 - 1) - 1` - // * `MIN_UTILITY = 0b10000000 = -128 = -127 - 1` - // - // Useful: https://www.omnicalculator.com/math/twos-complement - const MAX_UTILITY: SUtility = SUtility::Tie; - const MIN_UTILITY: SUtility = SUtility::Win; - - // The maximum numeric remoteness value that can be expressed with exactly - // REMOTENESS_SIZE bits in an unsigned integer. - const MAX_REMOTENESS: Remoteness = 2_u64.pow(REMOTENESS_SIZE as u32) - 1; - const MAX_CHILD_COUNT: ChildCount = 2_u64.pow(CHILD_COUNT_SIZE as u32) - 1; - - #[test] - fn initialize_with_valid_player_count() { - for i in 0..=RecordBuffer::player_count(BUFFER_SIZE) { - assert!(RecordBuffer::new(i).is_ok()) - } - } - - #[test] - fn initialize_with_invalid_player_count() { - let max = RecordBuffer::player_count(BUFFER_SIZE); - - assert!(RecordBuffer::new(max + 1).is_err()); - assert!(RecordBuffer::new(max + 10).is_err()); - assert!(RecordBuffer::new(max + 100).is_err()); - } - - #[test] - fn initialize_from_valid_buffer() { - let buf = bitarr!(u8, Msb0; 0; BUFFER_SIZE); - for i in REMOTENESS_SIZE..BUFFER_SIZE { - assert!(RecordBuffer::from(&buf[0..i]).is_ok()); - } - } - - #[test] - fn initialize_from_invalid_buffer() { - let buf1 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 1); - let buf2 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 10); - let buf3 = bitarr!(u8, Msb0; 0; BUFFER_SIZE + 100); - - assert!(RecordBuffer::from(&buf1).is_err()); - assert!(RecordBuffer::from(&buf2).is_err()); - assert!(RecordBuffer::from(&buf3).is_err()); - } - - #[test] - fn set_record_attributes() { - let mut r1 = RecordBuffer::new(7).unwrap(); - let mut r2 = RecordBuffer::new(4).unwrap(); - let mut r3 = RecordBuffer::new(0).unwrap(); - - let v1 = [SUtility::WIN; 7]; - let v2 = [SUtility::TIE; 4]; - let v3: [SUtility; 0] = []; - - let v4 = [MAX_UTILITY; 7]; - let v5 = [MIN_UTILITY; 4]; - let v6 = [SUtility::DRAW]; - - let good = Remoteness::MIN; - let bad = Remoteness::MAX; - - assert!(r1.set_utility(v1).is_ok()); - assert!(r2.set_utility(v2).is_ok()); - assert!(r3.set_utility(v3).is_ok()); - assert!(r1.set_utility(v4).is_ok()); - assert!(r2.set_utility(v5).is_ok()); - assert!(r3.set_utility(v6).is_err()); - - assert!(r1.set_remoteness(good).is_ok()); - assert!(r2.set_remoteness(good).is_ok()); - assert!(r3.set_remoteness(good).is_ok()); - - assert!(r1.set_remoteness(bad).is_err()); - assert!(r2.set_remoteness(bad).is_err()); - assert!(r3.set_remoteness(bad).is_err()); - } - - #[test] - fn data_is_valid_after_round_trip() { - let mut record = RecordBuffer::new(5).unwrap(); - let payoffs = [ - SUtility::LOSE, - SUtility::WIN, - SUtility::LOSE, - SUtility::LOSE, - SUtility::LOSE, - ]; - let remoteness = 790; - - record - .set_utility(payoffs) - .unwrap(); - - record - .set_remoteness(remoteness) - .unwrap(); - - // Utilities unchanged after insert and fetch - for i in 0..5 { - let fetched_utility = record.get_utility(i).unwrap(); - let actual_utility = payoffs[i]; - assert!(matches!(fetched_utility, actual_utility)); - } - - // Remoteness unchanged after insert and fetch - let fetched_remoteness = record.get_remoteness(); - let actual_remoteness = remoteness; - assert_eq!(fetched_remoteness, actual_remoteness); - - // Fetching utility entries of invalid players - assert!(record.get_utility(5).is_err()); - assert!(record.get_utility(10).is_err()); - } - - #[test] - fn extreme_data_is_valid_after_round_trip() { - let mut record = RecordBuffer::new(6).unwrap(); - - let good = [ - SUtility::WIN, - SUtility::LOSE, - SUtility::TIE, - SUtility::TIE, - SUtility::DRAW, - SUtility::WIN, - ]; - - let bad = [SUtility::DRAW, SUtility::WIN, SUtility::TIE]; - - assert!(record.set_utility(good).is_ok()); - assert!(record - .set_remoteness(MAX_REMOTENESS) - .is_ok()); - - for i in 0..6 { - let fetched_utility = record.get_utility(i).unwrap(); - let actual_utility = good[i]; - assert!(matches!(fetched_utility, actual_utility)); - } - - assert_eq!(record.get_remoteness(), MAX_REMOTENESS); - assert!(record.set_utility(bad).is_err()); - } - - #[test] - fn child_counts_retrieved_properly() -> Result<()> { - let mut buf = RecordBuffer::new(3)?; - buf.set_child_count(4)?; - - assert_eq!(buf.get_child_count(), 4); - - Ok(()) - } + // TODO }